Преглед на файлове

fix: import wrong user (#32)

John Wang преди 1 година
родител
ревизия
f5b2271c8c
променени са 1 файла, в които са добавени 5 реда и са изтрити 7 реда
  1. 5 7
      api/core/conversation_message_task.py

+ 5 - 7
api/core/conversation_message_task.py

@@ -2,8 +2,6 @@ import decimal
 import json
 from typing import Optional, Union
 
-from gunicorn.config import User
-
 from core.callback_handler.entity.agent_loop import AgentLoop
 from core.callback_handler.entity.dataset_query import DatasetQueryObj
 from core.callback_handler.entity.llm_message import LLMMessage
@@ -269,7 +267,7 @@ class ConversationMessageTask:
 
 
 class PubHandler:
-    def __init__(self, user: Union[Account | User], task_id: str,
+    def __init__(self, user: Union[Account | EndUser], task_id: str,
                  message: Message, conversation: Conversation,
                  chain_pub: bool = False, agent_thought_pub: bool = False):
         self._channel = PubHandler.generate_channel_name(user, task_id)
@@ -282,12 +280,12 @@ class PubHandler:
         self._agent_thought_pub = agent_thought_pub
 
     @classmethod
-    def generate_channel_name(cls, user: Union[Account | User], task_id: str):
+    def generate_channel_name(cls, user: Union[Account | EndUser], task_id: str):
         user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id
         return "generate_result:{}-{}".format(user_str, task_id)
 
     @classmethod
-    def generate_stopped_cache_key(cls, user: Union[Account | User], task_id: str):
+    def generate_stopped_cache_key(cls, user: Union[Account | EndUser], task_id: str):
         user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id
         return "generate_result_stopped:{}-{}".format(user_str, task_id)
 
@@ -366,7 +364,7 @@ class PubHandler:
         redis_client.publish(self._channel, json.dumps(content))
 
     @classmethod
-    def pub_error(cls, user: Union[Account | User], task_id: str, e):
+    def pub_error(cls, user: Union[Account | EndUser], task_id: str, e):
         content = {
             'error': type(e).__name__,
             'description': e.description if getattr(e, 'description', None) is not None else str(e)
@@ -379,7 +377,7 @@ class PubHandler:
         return redis_client.get(self._stopped_cache_key) is not None
 
     @classmethod
-    def stop(cls, user: Union[Account | User], task_id: str):
+    def stop(cls, user: Union[Account | EndUser], task_id: str):
         stopped_cache_key = cls.generate_stopped_cache_key(user, task_id)
         redis_client.setex(stopped_cache_key, 600, 1)