|
@@ -1,5 +1,6 @@
|
|
|
import queue
|
|
|
import time
|
|
|
+from enum import Enum
|
|
|
from typing import Generator, Any
|
|
|
|
|
|
from sqlalchemy.orm import DeclarativeMeta
|
|
@@ -13,6 +14,11 @@ from extensions.ext_redis import redis_client
|
|
|
from models.model import MessageAgentThought
|
|
|
|
|
|
|
|
|
+class PublishFrom(Enum):
|
|
|
+ APPLICATION_MANAGER = 1
|
|
|
+ TASK_PIPELINE = 2
|
|
|
+
|
|
|
+
|
|
|
class ApplicationQueueManager:
|
|
|
def __init__(self, task_id: str,
|
|
|
user_id: str,
|
|
@@ -61,11 +67,14 @@ class ApplicationQueueManager:
|
|
|
if elapsed_time >= listen_timeout or self._is_stopped():
|
|
|
# publish two messages to make sure the client can receive the stop signal
|
|
|
# and stop listening after the stop signal processed
|
|
|
- self.publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL))
|
|
|
+ self.publish(
|
|
|
+ QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL),
|
|
|
+ PublishFrom.TASK_PIPELINE
|
|
|
+ )
|
|
|
self.stop_listen()
|
|
|
|
|
|
if elapsed_time // 10 > last_ping_time:
|
|
|
- self.publish(QueuePingEvent())
|
|
|
+ self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE)
|
|
|
last_ping_time = elapsed_time // 10
|
|
|
|
|
|
def stop_listen(self) -> None:
|
|
@@ -75,76 +84,83 @@ class ApplicationQueueManager:
|
|
|
"""
|
|
|
self._q.put(None)
|
|
|
|
|
|
- def publish_chunk_message(self, chunk: LLMResultChunk) -> None:
|
|
|
+ def publish_chunk_message(self, chunk: LLMResultChunk, pub_from: PublishFrom) -> None:
|
|
|
"""
|
|
|
Publish chunk message to channel
|
|
|
|
|
|
:param chunk: chunk
|
|
|
+ :param pub_from: publish from
|
|
|
:return:
|
|
|
"""
|
|
|
self.publish(QueueMessageEvent(
|
|
|
chunk=chunk
|
|
|
- ))
|
|
|
+ ), pub_from)
|
|
|
|
|
|
- def publish_message_replace(self, text: str) -> None:
|
|
|
+ def publish_message_replace(self, text: str, pub_from: PublishFrom) -> None:
|
|
|
"""
|
|
|
Publish message replace
|
|
|
:param text: text
|
|
|
+ :param pub_from: publish from
|
|
|
:return:
|
|
|
"""
|
|
|
self.publish(QueueMessageReplaceEvent(
|
|
|
text=text
|
|
|
- ))
|
|
|
+ ), pub_from)
|
|
|
|
|
|
- def publish_retriever_resources(self, retriever_resources: list[dict]) -> None:
|
|
|
+ def publish_retriever_resources(self, retriever_resources: list[dict], pub_from: PublishFrom) -> None:
|
|
|
"""
|
|
|
Publish retriever resources
|
|
|
:return:
|
|
|
"""
|
|
|
- self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources))
|
|
|
+ self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources), pub_from)
|
|
|
|
|
|
- def publish_annotation_reply(self, message_annotation_id: str) -> None:
|
|
|
+ def publish_annotation_reply(self, message_annotation_id: str, pub_from: PublishFrom) -> None:
|
|
|
"""
|
|
|
Publish annotation reply
|
|
|
:param message_annotation_id: message annotation id
|
|
|
+ :param pub_from: publish from
|
|
|
:return:
|
|
|
"""
|
|
|
- self.publish(AnnotationReplyEvent(message_annotation_id=message_annotation_id))
|
|
|
+ self.publish(AnnotationReplyEvent(message_annotation_id=message_annotation_id), pub_from)
|
|
|
|
|
|
- def publish_message_end(self, llm_result: LLMResult) -> None:
|
|
|
+ def publish_message_end(self, llm_result: LLMResult, pub_from: PublishFrom) -> None:
|
|
|
"""
|
|
|
Publish message end
|
|
|
:param llm_result: llm result
|
|
|
+ :param pub_from: publish from
|
|
|
:return:
|
|
|
"""
|
|
|
- self.publish(QueueMessageEndEvent(llm_result=llm_result))
|
|
|
+ self.publish(QueueMessageEndEvent(llm_result=llm_result), pub_from)
|
|
|
self.stop_listen()
|
|
|
|
|
|
- def publish_agent_thought(self, message_agent_thought: MessageAgentThought) -> None:
|
|
|
+ def publish_agent_thought(self, message_agent_thought: MessageAgentThought, pub_from: PublishFrom) -> None:
|
|
|
"""
|
|
|
Publish agent thought
|
|
|
:param message_agent_thought: message agent thought
|
|
|
+ :param pub_from: publish from
|
|
|
:return:
|
|
|
"""
|
|
|
self.publish(QueueAgentThoughtEvent(
|
|
|
agent_thought_id=message_agent_thought.id
|
|
|
- ))
|
|
|
+ ), pub_from)
|
|
|
|
|
|
- def publish_error(self, e) -> None:
|
|
|
+ def publish_error(self, e, pub_from: PublishFrom) -> None:
|
|
|
"""
|
|
|
Publish error
|
|
|
:param e: error
|
|
|
+ :param pub_from: publish from
|
|
|
:return:
|
|
|
"""
|
|
|
self.publish(QueueErrorEvent(
|
|
|
error=e
|
|
|
- ))
|
|
|
+ ), pub_from)
|
|
|
self.stop_listen()
|
|
|
|
|
|
- def publish(self, event: AppQueueEvent) -> None:
|
|
|
+ def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
|
|
|
"""
|
|
|
Publish event to queue
|
|
|
:param event:
|
|
|
+ :param pub_from:
|
|
|
:return:
|
|
|
"""
|
|
|
self._check_for_sqlalchemy_models(event.dict())
|
|
@@ -162,6 +178,9 @@ class ApplicationQueueManager:
|
|
|
if isinstance(event, QueueStopEvent):
|
|
|
self.stop_listen()
|
|
|
|
|
|
+ if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
|
|
|
+ raise ConversationTaskStoppedException()
|
|
|
+
|
|
|
@classmethod
|
|
|
def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None:
|
|
|
"""
|
|
@@ -187,7 +206,6 @@ class ApplicationQueueManager:
|
|
|
stopped_cache_key = ApplicationQueueManager._generate_stopped_cache_key(self._task_id)
|
|
|
result = redis_client.get(stopped_cache_key)
|
|
|
if result is not None:
|
|
|
- redis_client.delete(stopped_cache_key)
|
|
|
return True
|
|
|
|
|
|
return False
|