소스 검색

fix: workflow trace user_id error (#6932)

Joe 8 달 전
부모
커밋
26e46d365c

+ 2 - 1
api/core/app/apps/agent_chat/app_generator.py

@@ -110,7 +110,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
         )
 
         # get tracing instance
-        trace_manager = TraceQueueManager(app_model.id)
+        user_id = user.id if isinstance(user, Account) else user.session_id
+        trace_manager = TraceQueueManager(app_model.id, user_id)
 
         # init application generate entity
         application_generate_entity = AgentChatAppGenerateEntity(

+ 2 - 1
api/core/app/apps/workflow/app_generator.py

@@ -74,7 +74,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
         )
 
         # get tracing instance
-        trace_manager = TraceQueueManager(app_model.id)
+        user_id = user.id if isinstance(user, Account) else user.session_id
+        trace_manager = TraceQueueManager(app_model.id, user_id)
 
         # init application generate entity
         application_generate_entity = WorkflowAppGenerateEntity(

+ 2 - 0
api/core/app/task_pipeline/workflow_cycle_manage.py

@@ -131,6 +131,7 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
                     TraceTaskName.WORKFLOW_TRACE,
                     workflow_run=workflow_run,
                     conversation_id=conversation_id,
+                    user_id=trace_manager.user_id,
                 )
             )
 
@@ -173,6 +174,7 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
                     TraceTaskName.WORKFLOW_TRACE,
                     workflow_run=workflow_run,
                     conversation_id=conversation_id,
+                    user_id=trace_manager.user_id,
                 )
             )
 

+ 3 - 2
api/core/ops/langfuse_trace/langfuse_trace.py

@@ -65,12 +65,13 @@ class LangFuseDataTrace(BaseTraceInstance):
 
     def workflow_trace(self, trace_info: WorkflowTraceInfo):
         trace_id = trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id
+        user_id = trace_info.metadata.get("user_id")
         if trace_info.message_id:
             trace_id = trace_info.message_id
             name = f"message_{trace_info.message_id}"
             trace_data = LangfuseTrace(
                 id=trace_info.message_id,
-                user_id=trace_info.tenant_id,
+                user_id=user_id,
                 name=name,
                 input=trace_info.workflow_run_inputs,
                 output=trace_info.workflow_run_outputs,
@@ -95,7 +96,7 @@ class LangFuseDataTrace(BaseTraceInstance):
         else:
             trace_data = LangfuseTrace(
                 id=trace_id,
-                user_id=trace_info.tenant_id,
+                user_id=user_id,
                 name=f"workflow_{trace_info.workflow_app_log_id}" if trace_info.workflow_app_log_id else f"workflow_{trace_info.workflow_run_id}",
                 input=trace_info.workflow_run_inputs,
                 output=trace_info.workflow_run_outputs,

+ 9 - 3
api/core/ops/ops_trace_manager.py

@@ -271,6 +271,7 @@ class TraceTask:
         message_id: Optional[str] = None,
         workflow_run: Optional[WorkflowRun] = None,
         conversation_id: Optional[str] = None,
+        user_id: Optional[str] = None,
         timer: Optional[Any] = None,
         **kwargs
     ):
@@ -278,6 +279,7 @@ class TraceTask:
         self.message_id = message_id
         self.workflow_run = workflow_run
         self.conversation_id = conversation_id
+        self.user_id = user_id
         self.timer = timer
         self.kwargs = kwargs
         self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
@@ -290,7 +292,9 @@ class TraceTask:
     def preprocess(self):
         preprocess_map = {
             TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs),
-            TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace(self.workflow_run, self.conversation_id),
+            TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace(
+                self.workflow_run, self.conversation_id, self.user_id
+            ),
             TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(self.message_id),
             TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(
                 self.message_id, self.timer, **self.kwargs
@@ -313,7 +317,7 @@ class TraceTask:
     def conversation_trace(self, **kwargs):
         return kwargs
 
-    def workflow_trace(self, workflow_run: WorkflowRun, conversation_id):
+    def workflow_trace(self, workflow_run: WorkflowRun, conversation_id, user_id):
         workflow_id = workflow_run.workflow_id
         tenant_id = workflow_run.tenant_id
         workflow_run_id = workflow_run.id
@@ -358,6 +362,7 @@ class TraceTask:
             "total_tokens": total_tokens,
             "file_list": file_list,
             "triggered_form": workflow_run.triggered_from,
+            "user_id": user_id,
         }
 
         workflow_trace_info = WorkflowTraceInfo(
@@ -654,10 +659,11 @@ trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100))
 
 
 class TraceQueueManager:
-    def __init__(self, app_id=None):
+    def __init__(self, app_id=None, user_id=None):
         global trace_manager_timer
 
         self.app_id = app_id
+        self.user_id = user_id
         self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)
         self.flask_app = current_app._get_current_object()
         if trace_manager_timer is None: