Browse Source Code

fix: generate not stop when pressing stop link (#1961)

takatost 1 year ago
parent
commit
0c746f5c5a

+ 9 - 6
api/core/app_runner/app_runner.py

@@ -1,7 +1,7 @@
 import time
 from typing import cast, Optional, List, Tuple, Generator, Union
 
-from core.application_queue_manager import ApplicationQueueManager
+from core.application_queue_manager import ApplicationQueueManager, PublishFrom
 from core.entities.application_entities import ModelConfigEntity, PromptTemplateEntity, AppOrchestrationConfigEntity
 from core.file.file_obj import FileObj
 from core.memory.token_buffer_memory import TokenBufferMemory
@@ -183,7 +183,7 @@ class AppRunner:
                         index=index,
                         message=AssistantPromptMessage(content=token)
                     )
-                ))
+                ), PublishFrom.APPLICATION_MANAGER)
                 index += 1
                 time.sleep(0.01)
 
@@ -193,7 +193,8 @@ class AppRunner:
                 prompt_messages=prompt_messages,
                 message=AssistantPromptMessage(content=text),
                 usage=usage if usage else LLMUsage.empty_usage()
-            )
+            ),
+            pub_from=PublishFrom.APPLICATION_MANAGER
         )
 
     def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator],
@@ -226,7 +227,8 @@ class AppRunner:
         :return:
         """
         queue_manager.publish_message_end(
-            llm_result=invoke_result
+            llm_result=invoke_result,
+            pub_from=PublishFrom.APPLICATION_MANAGER
         )
 
     def _handle_invoke_result_stream(self, invoke_result: Generator,
@@ -242,7 +244,7 @@ class AppRunner:
         text = ''
         usage = None
         for result in invoke_result:
-            queue_manager.publish_chunk_message(result)
+            queue_manager.publish_chunk_message(result, PublishFrom.APPLICATION_MANAGER)
 
             text += result.delta.message.content
 
@@ -263,5 +265,6 @@ class AppRunner:
         )
 
         queue_manager.publish_message_end(
-            llm_result=llm_result
+            llm_result=llm_result,
+            pub_from=PublishFrom.APPLICATION_MANAGER
         )

+ 3 - 2
api/core/app_runner/basic_app_runner.py

@@ -5,7 +5,7 @@ from core.app_runner.app_runner import AppRunner
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.entities.application_entities import ApplicationGenerateEntity, ModelConfigEntity, \
     AppOrchestrationConfigEntity, InvokeFrom, ExternalDataVariableEntity, DatasetEntity
-from core.application_queue_manager import ApplicationQueueManager
+from core.application_queue_manager import ApplicationQueueManager, PublishFrom
 from core.features.annotation_reply import AnnotationReplyFeature
 from core.features.dataset_retrieval import DatasetRetrievalFeature
 from core.features.external_data_fetch import ExternalDataFetchFeature
@@ -121,7 +121,8 @@ class BasicApplicationRunner(AppRunner):
 
             if annotation_reply:
                 queue_manager.publish_annotation_reply(
-                    message_annotation_id=annotation_reply.id
+                    message_annotation_id=annotation_reply.id,
+                    pub_from=PublishFrom.APPLICATION_MANAGER
                 )
                 self.direct_output(
                     queue_manager=queue_manager,

+ 6 - 3
api/core/app_runner/generate_task_pipeline.py

@@ -7,7 +7,7 @@ from pydantic import BaseModel
 
 from core.app_runner.moderation_handler import OutputModerationHandler, ModerationRule
 from core.entities.application_entities import ApplicationGenerateEntity
-from core.application_queue_manager import ApplicationQueueManager
+from core.application_queue_manager import ApplicationQueueManager, PublishFrom
 from core.entities.queue_entities import QueueErrorEvent, QueueStopEvent, QueueMessageEndEvent, \
     QueueRetrieverResourcesEvent, QueueAgentThoughtEvent, QueuePingEvent, QueueMessageEvent, QueueMessageReplaceEvent, \
     AnnotationReplyEvent
@@ -312,8 +312,11 @@ class GenerateTaskPipeline:
                                 index=0,
                                 message=AssistantPromptMessage(content=self._task_state.llm_result.message.content)
                             )
-                        ))
-                        self._queue_manager.publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION))
+                        ), PublishFrom.TASK_PIPELINE)
+                        self._queue_manager.publish(
+                            QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
+                            PublishFrom.TASK_PIPELINE
+                        )
                         continue
                     else:
                         self._output_moderation_handler.append_new_token(delta_text)

+ 2 - 1
api/core/app_runner/moderation_handler.py

@@ -6,6 +6,7 @@ from typing import Any, Optional, Dict
 from flask import current_app, Flask
 from pydantic import BaseModel
 
+from core.application_queue_manager import PublishFrom
 from core.moderation.base import ModerationAction, ModerationOutputsResult
 from core.moderation.factory import ModerationFactory
 
@@ -66,7 +67,7 @@ class OutputModerationHandler(BaseModel):
             final_output = result.text
 
         if public_event:
-            self.on_message_replace_func(final_output)
+            self.on_message_replace_func(final_output, PublishFrom.TASK_PIPELINE)
 
         return final_output
 

+ 8 - 5
api/core/application_manager.py

@@ -23,7 +23,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeErr
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.prompt.prompt_template import PromptTemplateParser
 from core.provider_manager import ProviderManager
-from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException
+from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException, PublishFrom
 from extensions.ext_database import db
 from models.account import Account
 from models.model import EndUser, Conversation, Message, MessageFile, App
@@ -169,15 +169,18 @@ class ApplicationManager:
             except ConversationTaskStoppedException:
                 pass
             except InvokeAuthorizationError:
-                queue_manager.publish_error(InvokeAuthorizationError('Incorrect API key provided'))
+                queue_manager.publish_error(
+                    InvokeAuthorizationError('Incorrect API key provided'),
+                    PublishFrom.APPLICATION_MANAGER
+                )
             except ValidationError as e:
                 logger.exception("Validation Error when generating")
-                queue_manager.publish_error(e)
+                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
             except (ValueError, InvokeError) as e:
-                queue_manager.publish_error(e)
+                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
             except Exception as e:
                 logger.exception("Unknown Error when generating")
-                queue_manager.publish_error(e)
+                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
             finally:
                 db.session.remove()
 

+ 36 - 18
api/core/application_queue_manager.py

@@ -1,5 +1,6 @@
 import queue
 import time
+from enum import Enum
 from typing import Generator, Any
 
 from sqlalchemy.orm import DeclarativeMeta
@@ -13,6 +14,11 @@ from extensions.ext_redis import redis_client
 from models.model import MessageAgentThought
 
 
+class PublishFrom(Enum):
+    APPLICATION_MANAGER = 1
+    TASK_PIPELINE = 2
+
+
 class ApplicationQueueManager:
     def __init__(self, task_id: str,
                  user_id: str,
@@ -61,11 +67,14 @@ class ApplicationQueueManager:
                 if elapsed_time >= listen_timeout or self._is_stopped():
                     # publish two messages to make sure the client can receive the stop signal
                     # and stop listening after the stop signal processed
-                    self.publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL))
+                    self.publish(
+                        QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL),
+                        PublishFrom.TASK_PIPELINE
+                    )
                     self.stop_listen()
 
                 if elapsed_time // 10 > last_ping_time:
-                    self.publish(QueuePingEvent())
+                    self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE)
                     last_ping_time = elapsed_time // 10
 
     def stop_listen(self) -> None:
@@ -75,76 +84,83 @@ class ApplicationQueueManager:
         """
         self._q.put(None)
 
-    def publish_chunk_message(self, chunk: LLMResultChunk) -> None:
+    def publish_chunk_message(self, chunk: LLMResultChunk, pub_from: PublishFrom) -> None:
         """
         Publish chunk message to channel
 
         :param chunk: chunk
+        :param pub_from: publish from
         :return:
         """
         self.publish(QueueMessageEvent(
             chunk=chunk
-        ))
+        ), pub_from)
 
-    def publish_message_replace(self, text: str) -> None:
+    def publish_message_replace(self, text: str, pub_from: PublishFrom) -> None:
         """
         Publish message replace
         :param text: text
+        :param pub_from: publish from
         :return:
         """
         self.publish(QueueMessageReplaceEvent(
             text=text
-        ))
+        ), pub_from)
 
-    def publish_retriever_resources(self, retriever_resources: list[dict]) -> None:
+    def publish_retriever_resources(self, retriever_resources: list[dict], pub_from: PublishFrom) -> None:
         """
         Publish retriever resources
         :return:
         """
-        self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources))
+        self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources), pub_from)
 
-    def publish_annotation_reply(self, message_annotation_id: str) -> None:
+    def publish_annotation_reply(self, message_annotation_id: str, pub_from: PublishFrom) -> None:
         """
         Publish annotation reply
         :param message_annotation_id: message annotation id
+        :param pub_from: publish from
         :return:
         """
-        self.publish(AnnotationReplyEvent(message_annotation_id=message_annotation_id))
+        self.publish(AnnotationReplyEvent(message_annotation_id=message_annotation_id), pub_from)
 
-    def publish_message_end(self, llm_result: LLMResult) -> None:
+    def publish_message_end(self, llm_result: LLMResult, pub_from: PublishFrom) -> None:
         """
         Publish message end
         :param llm_result: llm result
+        :param pub_from: publish from
         :return:
         """
-        self.publish(QueueMessageEndEvent(llm_result=llm_result))
+        self.publish(QueueMessageEndEvent(llm_result=llm_result), pub_from)
         self.stop_listen()
 
-    def publish_agent_thought(self, message_agent_thought: MessageAgentThought) -> None:
+    def publish_agent_thought(self, message_agent_thought: MessageAgentThought, pub_from: PublishFrom) -> None:
         """
         Publish agent thought
         :param message_agent_thought: message agent thought
+        :param pub_from: publish from
         :return:
         """
         self.publish(QueueAgentThoughtEvent(
             agent_thought_id=message_agent_thought.id
-        ))
+        ), pub_from)
 
-    def publish_error(self, e) -> None:
+    def publish_error(self, e, pub_from: PublishFrom) -> None:
         """
         Publish error
         :param e: error
+        :param pub_from: publish from
         :return:
         """
         self.publish(QueueErrorEvent(
             error=e
-        ))
+        ), pub_from)
         self.stop_listen()
 
-    def publish(self, event: AppQueueEvent) -> None:
+    def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
         """
         Publish event to queue
         :param event:
+        :param pub_from:
         :return:
         """
         self._check_for_sqlalchemy_models(event.dict())
@@ -162,6 +178,9 @@ class ApplicationQueueManager:
         if isinstance(event, QueueStopEvent):
             self.stop_listen()
 
+        if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
+            raise ConversationTaskStoppedException()
+
     @classmethod
     def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None:
         """
@@ -187,7 +206,6 @@ class ApplicationQueueManager:
         stopped_cache_key = ApplicationQueueManager._generate_stopped_cache_key(self._task_id)
         result = redis_client.get(stopped_cache_key)
         if result is not None:
-            redis_client.delete(stopped_cache_key)
             return True
 
         return False

+ 2 - 2
api/core/callback_handler/agent_loop_gather_callback_handler.py

@@ -8,7 +8,7 @@ from langchain.agents import openai_functions_agent, openai_functions_multi_agen
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage
 
-from core.application_queue_manager import ApplicationQueueManager
+from core.application_queue_manager import ApplicationQueueManager, PublishFrom
 from core.callback_handler.entity.agent_loop import AgentLoop
 from core.entities.application_entities import ModelConfigEntity
 from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult
@@ -232,7 +232,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         db.session.add(message_agent_thought)
         db.session.commit()
 
-        self.queue_manager.publish_agent_thought(message_agent_thought)
+        self.queue_manager.publish_agent_thought(message_agent_thought, PublishFrom.APPLICATION_MANAGER)
 
         return message_agent_thought
 

+ 2 - 2
api/core/callback_handler/index_tool_callback_handler.py

@@ -2,7 +2,7 @@ from typing import List, Union
 
 from langchain.schema import Document
 
-from core.application_queue_manager import ApplicationQueueManager
+from core.application_queue_manager import ApplicationQueueManager, PublishFrom
 from core.entities.application_entities import InvokeFrom
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, DatasetQuery
@@ -80,4 +80,4 @@ class DatasetIndexToolCallbackHandler:
                 db.session.add(dataset_retriever_resource)
                 db.session.commit()
 
-        self._queue_manager.publish_retriever_resources(resource)
+        self._queue_manager.publish_retriever_resources(resource, PublishFrom.APPLICATION_MANAGER)