Selaa lähdekoodia

fix: enhance type hints and improve audio message handling in TTS pub… (#11947)

Signed-off-by: -LAN- <laipz8200@outlook.com>
-LAN- 4 kuukautta sitten
vanhempi
commit
dd0e81d094

+ 19 - 13
api/core/app/apps/advanced_chat/app_generator_tts_publisher.py

@@ -4,14 +4,17 @@ import logging
 import queue
 import re
 import threading
+from collections.abc import Iterable
 
 from core.app.entities.queue_entities import (
+    MessageQueueMessage,
     QueueAgentMessageEvent,
     QueueLLMChunkEvent,
     QueueNodeSucceededEvent,
     QueueTextChunkEvent,
+    WorkflowQueueMessage,
 )
-from core.model_manager import ModelManager
+from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 
 
@@ -21,7 +24,7 @@ class AudioTrunk:
         self.status = status
 
 
-def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str):
+def _invoice_tts(text_content: str, model_instance: ModelInstance, tenant_id: str, voice: str):
     if not text_content or text_content.isspace():
         return
     return model_instance.invoke_tts(
@@ -29,13 +32,19 @@ def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str):
     )
 
 
-def _process_future(future_queue, audio_queue):
+def _process_future(
+    future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None],
+    audio_queue: queue.Queue[AudioTrunk],
+):
     while True:
         try:
             future = future_queue.get()
             if future is None:
                 break
-            for audio in future.result():
+            invoke_result = future.result()
+            if not invoke_result:
+                continue
+            for audio in invoke_result:
                 audio_base64 = base64.b64encode(bytes(audio))
                 audio_queue.put(AudioTrunk("responding", audio=audio_base64))
         except Exception as e:
@@ -49,8 +58,8 @@ class AppGeneratorTTSPublisher:
         self.logger = logging.getLogger(__name__)
         self.tenant_id = tenant_id
         self.msg_text = ""
-        self._audio_queue = queue.Queue()
-        self._msg_queue = queue.Queue()
+        self._audio_queue: queue.Queue[AudioTrunk] = queue.Queue()
+        self._msg_queue: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
         self.match = re.compile(r"[。.!?]")
         self.model_manager = ModelManager()
         self.model_instance = self.model_manager.get_default_model_instance(
@@ -66,14 +75,11 @@ class AppGeneratorTTSPublisher:
         self._runtime_thread = threading.Thread(target=self._runtime).start()
         self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)
 
-    def publish(self, message):
-        try:
-            self._msg_queue.put(message)
-        except Exception as e:
-            self.logger.warning(e)
+    def publish(self, message: WorkflowQueueMessage | MessageQueueMessage | None, /):
+        self._msg_queue.put(message)
 
     def _runtime(self):
-        future_queue = queue.Queue()
+        future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None] = queue.Queue()
         threading.Thread(target=_process_future, args=(future_queue, self._audio_queue)).start()
         while True:
             try:
@@ -110,7 +116,7 @@ class AppGeneratorTTSPublisher:
                 break
         future_queue.put(None)
 
-    def check_and_get_audio(self) -> AudioTrunk | None:
+    def check_and_get_audio(self):
         try:
             if self._last_audio_event and self._last_audio_event.status == "finish":
                 if self.executor:

+ 5 - 5
api/core/app/apps/advanced_chat/generate_task_pipeline.py

@@ -197,11 +197,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 stream_response=stream_response,
             )
 
-    def _listen_audio_msg(self, publisher, task_id: str):
+    def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
         if not publisher:
             return None
-        audio_msg: AudioTrunk = publisher.check_and_get_audio()
-        if audio_msg and audio_msg.status != "finish":
+        audio_msg = publisher.check_and_get_audio()
+        if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
             return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
         return None
 
@@ -222,7 +222,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
 
         for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
             while True:
-                audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
+                audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id)
                 if audio_response:
                     yield audio_response
                 else:
@@ -511,7 +511,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
 
                 # only publish tts message at text chunk streaming
                 if tts_publisher:
-                    tts_publisher.publish(message=queue_message)
+                    tts_publisher.publish(queue_message)
 
                 self._task_state.answer += delta_text
                 yield self._message_to_stream_response(

+ 4 - 3
api/core/app/apps/base_app_queue_manager.py

@@ -1,7 +1,6 @@
 import queue
 import time
 from abc import abstractmethod
-from collections.abc import Generator
 from enum import Enum
 from typing import Any
 
@@ -11,9 +10,11 @@ from configs import dify_config
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.queue_entities import (
     AppQueueEvent,
+    MessageQueueMessage,
     QueueErrorEvent,
     QueuePingEvent,
     QueueStopEvent,
+    WorkflowQueueMessage,
 )
 from extensions.ext_redis import redis_client
 
@@ -37,11 +38,11 @@ class AppQueueManager:
             AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
         )
 
-        q = queue.Queue()
+        q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
 
         self._q = q
 
-    def listen(self) -> Generator:
+    def listen(self):
         """
         Listen to queue
         :return:

+ 5 - 5
api/core/app/apps/workflow/generate_task_pipeline.py

@@ -171,11 +171,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
 
             yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response)
 
-    def _listen_audio_msg(self, publisher, task_id: str):
+    def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
         if not publisher:
             return None
-        audio_msg: AudioTrunk = publisher.check_and_get_audio()
-        if audio_msg and audio_msg.status != "finish":
+        audio_msg = publisher.check_and_get_audio()
+        if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
             return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
         return None
 
@@ -196,7 +196,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
 
         for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
             while True:
-                audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
+                audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id)
                 if audio_response:
                     yield audio_response
                 else:
@@ -421,7 +421,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
 
                 # only publish tts message at text chunk streaming
                 if tts_publisher:
-                    tts_publisher.publish(message=queue_message)
+                    tts_publisher.publish(queue_message)
 
                 self._task_state.answer += delta_text
                 yield self._text_chunk_to_stream_response(

+ 3 - 3
api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py

@@ -201,11 +201,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                     stream_response=stream_response,
                 )
 
-    def _listen_audio_msg(self, publisher, task_id: str):
+    def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
         if publisher is None:
             return None
-        audio_msg: AudioTrunk = publisher.check_and_get_audio()
-        if audio_msg and audio_msg.status != "finish":
+        audio_msg = publisher.check_and_get_audio()
+        if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
             # audio_str = audio_msg.audio.decode('utf-8', errors='ignore')
             return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
         return None