Joe 8 місяців тому
батько
коміт
425174e82f

+ 2 - 1
api/core/app/apps/advanced_chat/app_generator.py

@@ -89,7 +89,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         )
 
         # get tracing instance
-        trace_manager = TraceQueueManager(app_id=app_model.id)
+        user_id = user.id if isinstance(user, Account) else user.session_id
+        trace_manager = TraceQueueManager(app_model.id, user_id)
 
         if invoke_from == InvokeFrom.DEBUGGER:
             # always enable retriever resource in debugger mode

+ 2 - 1
api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py

@@ -48,7 +48,8 @@ from core.model_runtime.entities.message_entities import (
 )
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
+from core.ops.entities.trace_entity import TraceTaskName
+from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from events.message_event import message_was_created

+ 2 - 1
api/core/app/task_pipeline/workflow_cycle_manage.py

@@ -22,7 +22,8 @@ from core.app.entities.task_entities import (
 from core.app.task_pipeline.workflow_iteration_cycle_manage import WorkflowIterationCycleManage
 from core.file.file_obj import FileVar
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
+from core.ops.entities.trace_entity import TraceTaskName
+from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
 from core.tools.tool_manager import ToolManager
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
 from core.workflow.nodes.tool.entities import ToolNodeData

+ 2 - 1
api/core/callback_handler/agent_tool_callback_handler.py

@@ -4,7 +4,8 @@ from typing import Any, Optional, TextIO, Union
 
 from pydantic import BaseModel
 
-from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
+from core.ops.entities.trace_entity import TraceTaskName
+from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
 from core.tools.entities.tool_entities import ToolInvokeMessage
 
 _TEXT_COLOR_MAPPING = {

+ 2 - 1
api/core/llm_generator/llm_generator.py

@@ -14,7 +14,8 @@ from core.model_manager import ModelManager
 from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
-from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
+from core.ops.entities.trace_entity import TraceTaskName
+from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
 from core.ops.utils import measure_time
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 

+ 2 - 1
api/core/moderation/input_moderation.py

@@ -4,7 +4,8 @@ from typing import Optional
 from core.app.app_config.entities import AppConfig
 from core.moderation.base import ModerationAction, ModerationException
 from core.moderation.factory import ModerationFactory
-from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
+from core.ops.entities.trace_entity import TraceTaskName
+from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
 from core.ops.utils import measure_time
 
 logger = logging.getLogger(__name__)

+ 13 - 1
api/core/ops/entities/trace_entity.py

@@ -1,4 +1,5 @@
 from datetime import datetime
+from enum import Enum
 from typing import Any, Optional, Union
 
 from pydantic import BaseModel, ConfigDict, field_validator
@@ -105,4 +106,15 @@ trace_info_info_map = {
     'DatasetRetrievalTraceInfo': DatasetRetrievalTraceInfo,
     'ToolTraceInfo': ToolTraceInfo,
     'GenerateNameTraceInfo': GenerateNameTraceInfo,
-}
+}
+
+
+class TraceTaskName(str, Enum):
+    CONVERSATION_TRACE = 'conversation'
+    WORKFLOW_TRACE = 'workflow'
+    MESSAGE_TRACE = 'message'
+    MODERATION_TRACE = 'moderation'
+    SUGGESTED_QUESTION_TRACE = 'suggested_question'
+    DATASET_RETRIEVAL_TRACE = 'dataset_retrieval'
+    TOOL_TRACE = 'tool'
+    GENERATE_NAME_TRACE = 'generate_conversation_name'

+ 20 - 22
api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py

@@ -50,10 +50,11 @@ class LangfuseTrace(BaseModel):
     """
     Langfuse trace model
     """
+
     id: Optional[str] = Field(
         default=None,
         description="The id of the trace can be set, defaults to a random id. Used to link traces to external systems "
-                    "or when creating a distributed trace. Traces are upserted on id.",
+        "or when creating a distributed trace. Traces are upserted on id.",
     )
     name: Optional[str] = Field(
         default=None,
@@ -68,7 +69,7 @@ class LangfuseTrace(BaseModel):
     metadata: Optional[dict[str, Any]] = Field(
         default=None,
         description="Additional metadata of the trace. Can be any JSON object. Metadata is merged when being updated "
-                    "via the API.",
+        "via the API.",
     )
     user_id: Optional[str] = Field(
         default=None,
@@ -81,22 +82,22 @@ class LangfuseTrace(BaseModel):
     version: Optional[str] = Field(
         default=None,
         description="The version of the trace type. Used to understand how changes to the trace type affect metrics. "
-                    "Useful in debugging.",
+        "Useful in debugging.",
     )
     release: Optional[str] = Field(
         default=None,
         description="The release identifier of the current deployment. Used to understand how changes of different "
-                    "deployments affect metrics. Useful in debugging.",
+        "deployments affect metrics. Useful in debugging.",
     )
     tags: Optional[list[str]] = Field(
         default=None,
         description="Tags are used to categorize or label traces. Traces can be filtered by tags in the UI and GET "
-                    "API. Tags can also be changed in the UI. Tags are merged and never deleted via the API.",
+        "API. Tags can also be changed in the UI. Tags are merged and never deleted via the API.",
     )
     public: Optional[bool] = Field(
         default=None,
         description="You can make a trace public to share it via a public link. This allows others to view the trace "
-                    "without needing to log in or be members of your Langfuse project.",
+        "without needing to log in or be members of your Langfuse project.",
     )
 
     @field_validator("input", "output")
@@ -109,6 +110,7 @@ class LangfuseSpan(BaseModel):
     """
     Langfuse span model
     """
+
     id: Optional[str] = Field(
         default=None,
         description="The id of the span can be set, otherwise a random id is generated. Spans are upserted on id.",
@@ -140,17 +142,17 @@ class LangfuseSpan(BaseModel):
     metadata: Optional[dict[str, Any]] = Field(
         default=None,
         description="Additional metadata of the span. Can be any JSON object. Metadata is merged when being updated "
-                    "via the API.",
+        "via the API.",
     )
     level: Optional[str] = Field(
         default=None,
         description="The level of the span. Can be DEBUG, DEFAULT, WARNING or ERROR. Used for sorting/filtering of "
-                    "traces with elevated error levels and for highlighting in the UI.",
+        "traces with elevated error levels and for highlighting in the UI.",
     )
     status_message: Optional[str] = Field(
         default=None,
         description="The status message of the span. Additional field for context of the event. E.g. the error "
-                    "message of an error event.",
+        "message of an error event.",
     )
     input: Optional[Union[str, dict[str, Any], list, None]] = Field(
         default=None, description="The input of the span. Can be any JSON object."
@@ -161,7 +163,7 @@ class LangfuseSpan(BaseModel):
     version: Optional[str] = Field(
         default=None,
         description="The version of the span type. Used to understand how changes to the span type affect metrics. "
-                    "Useful in debugging.",
+        "Useful in debugging.",
     )
     parent_observation_id: Optional[str] = Field(
         default=None,
@@ -185,10 +187,9 @@ class UnitEnum(str, Enum):
 class GenerationUsage(BaseModel):
     promptTokens: Optional[int] = None
     completionTokens: Optional[int] = None
-    totalTokens: Optional[int] = None
+    total: Optional[int] = None
     input: Optional[int] = None
     output: Optional[int] = None
-    total: Optional[int] = None
     unit: Optional[UnitEnum] = None
     inputCost: Optional[float] = None
     outputCost: Optional[float] = None
@@ -224,15 +225,13 @@ class LangfuseGeneration(BaseModel):
     completion_start_time: Optional[datetime | str] = Field(
         default=None,
         description="The time at which the completion started (streaming). Set it to get latency analytics broken "
-                    "down into time until completion started and completion duration.",
+        "down into time until completion started and completion duration.",
     )
     end_time: Optional[datetime | str] = Field(
         default=None,
         description="The time at which the generation ended. Automatically set by generation.end().",
     )
-    model: Optional[str] = Field(
-        default=None, description="The name of the model used for the generation."
-    )
+    model: Optional[str] = Field(default=None, description="The name of the model used for the generation.")
     model_parameters: Optional[dict[str, Any]] = Field(
         default=None,
         description="The parameters of the model used for the generation; can be any key-value pairs.",
@@ -248,27 +247,27 @@ class LangfuseGeneration(BaseModel):
     usage: Optional[GenerationUsage] = Field(
         default=None,
         description="The usage object supports the OpenAi structure with tokens and a more generic version with "
-                    "detailed costs and units.",
+        "detailed costs and units.",
     )
     metadata: Optional[dict[str, Any]] = Field(
         default=None,
         description="Additional metadata of the generation. Can be any JSON object. Metadata is merged when being "
-                    "updated via the API.",
+        "updated via the API.",
     )
     level: Optional[LevelEnum] = Field(
         default=None,
         description="The level of the generation. Can be DEBUG, DEFAULT, WARNING or ERROR. Used for sorting/filtering "
-                    "of traces with elevated error levels and for highlighting in the UI.",
+        "of traces with elevated error levels and for highlighting in the UI.",
     )
     status_message: Optional[str] = Field(
         default=None,
         description="The status message of the generation. Additional field for context of the event. E.g. the error "
-                    "message of an error event.",
+        "message of an error event.",
     )
     version: Optional[str] = Field(
         default=None,
         description="The version of the generation type. Used to understand how changes to the span type affect "
-                    "metrics. Useful in debugging.",
+        "metrics. Useful in debugging.",
     )
 
     model_config = ConfigDict(protected_namespaces=())
@@ -277,4 +276,3 @@ class LangfuseGeneration(BaseModel):
     def ensure_dict(cls, v, info: ValidationInfo):
         field_name = info.field_name
         return validate_input_output(v, field_name)
-

+ 44 - 59
api/core/ops/langfuse_trace/langfuse_trace.py

@@ -16,6 +16,7 @@ from core.ops.entities.trace_entity import (
     ModerationTraceInfo,
     SuggestedQuestionTraceInfo,
     ToolTraceInfo,
+    TraceTaskName,
     WorkflowTraceInfo,
 )
 from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
@@ -68,9 +69,9 @@ class LangFuseDataTrace(BaseTraceInstance):
         user_id = trace_info.metadata.get("user_id")
         if trace_info.message_id:
             trace_id = trace_info.message_id
-            name = f"message_{trace_info.message_id}"
+            name = TraceTaskName.MESSAGE_TRACE.value
             trace_data = LangfuseTrace(
-                id=trace_info.message_id,
+                id=trace_id,
                 user_id=user_id,
                 name=name,
                 input=trace_info.workflow_run_inputs,
@@ -78,11 +79,13 @@ class LangFuseDataTrace(BaseTraceInstance):
                 metadata=trace_info.metadata,
                 session_id=trace_info.conversation_id,
                 tags=["message", "workflow"],
+                created_at=trace_info.start_time,
+                updated_at=trace_info.end_time,
             )
             self.add_trace(langfuse_trace_data=trace_data)
             workflow_span_data = LangfuseSpan(
-                id=trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id,
-                name=f"workflow_{trace_info.workflow_app_log_id}" if trace_info.workflow_app_log_id else f"workflow_{trace_info.workflow_run_id}",
+                id=(trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id),
+                name=TraceTaskName.WORKFLOW_TRACE.value,
                 input=trace_info.workflow_run_inputs,
                 output=trace_info.workflow_run_outputs,
                 trace_id=trace_id,
@@ -97,7 +100,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             trace_data = LangfuseTrace(
                 id=trace_id,
                 user_id=user_id,
-                name=f"workflow_{trace_info.workflow_app_log_id}" if trace_info.workflow_app_log_id else f"workflow_{trace_info.workflow_run_id}",
+                name=TraceTaskName.WORKFLOW_TRACE.value,
                 input=trace_info.workflow_run_inputs,
                 output=trace_info.workflow_run_outputs,
                 metadata=trace_info.metadata,
@@ -134,14 +137,12 @@ class LangFuseDataTrace(BaseTraceInstance):
             node_type = node_execution.node_type
             status = node_execution.status
             if node_type == "llm":
-                inputs = json.loads(node_execution.process_data).get(
-                    "prompts", {}
-                    ) if node_execution.process_data else {}
+                inputs = (
+                    json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
+                )
             else:
                 inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
-            outputs = (
-                json.loads(node_execution.outputs) if node_execution.outputs else {}
-            )
+            outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
             created_at = node_execution.created_at if node_execution.created_at else datetime.now()
             elapsed_time = node_execution.elapsed_time
             finished_at = created_at + timedelta(seconds=elapsed_time)
@@ -163,28 +164,30 @@ class LangFuseDataTrace(BaseTraceInstance):
             if trace_info.message_id:
                 span_data = LangfuseSpan(
                     id=node_execution_id,
-                    name=f"{node_name}_{node_execution_id}",
+                    name=node_type,
                     input=inputs,
                     output=outputs,
                     trace_id=trace_id,
                     start_time=created_at,
                     end_time=finished_at,
                     metadata=metadata,
-                    level=LevelEnum.DEFAULT if status == 'succeeded' else LevelEnum.ERROR,
+                    level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
                     status_message=trace_info.error if trace_info.error else "",
-                    parent_observation_id=trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id,
+                    parent_observation_id=(
+                        trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id
+                    ),
                 )
             else:
                 span_data = LangfuseSpan(
                     id=node_execution_id,
-                    name=f"{node_name}_{node_execution_id}",
+                    name=node_type,
                     input=inputs,
                     output=outputs,
                     trace_id=trace_id,
                     start_time=created_at,
                     end_time=finished_at,
                     metadata=metadata,
-                    level=LevelEnum.DEFAULT if status == 'succeeded' else LevelEnum.ERROR,
+                    level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
                     status_message=trace_info.error if trace_info.error else "",
                 )
 
@@ -195,11 +198,11 @@ class LangFuseDataTrace(BaseTraceInstance):
                 total_token = metadata.get("total_tokens", 0)
                 # add generation
                 generation_usage = GenerationUsage(
-                    totalTokens=total_token,
+                    total=total_token,
                 )
 
                 node_generation_data = LangfuseGeneration(
-                    name=f"generation_{node_execution_id}",
+                    name="llm",
                     trace_id=trace_id,
                     parent_observation_id=node_execution_id,
                     start_time=created_at,
@@ -207,16 +210,14 @@ class LangFuseDataTrace(BaseTraceInstance):
                     input=inputs,
                     output=outputs,
                     metadata=metadata,
-                    level=LevelEnum.DEFAULT if status == 'succeeded' else LevelEnum.ERROR,
+                    level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
                     status_message=trace_info.error if trace_info.error else "",
                     usage=generation_usage,
                 )
 
                 self.add_generation(langfuse_generation_data=node_generation_data)
 
-    def message_trace(
-        self, trace_info: MessageTraceInfo, **kwargs
-    ):
+    def message_trace(self, trace_info: MessageTraceInfo, **kwargs):
         # get message file data
         file_list = trace_info.file_list
         metadata = trace_info.metadata
@@ -225,9 +226,9 @@ class LangFuseDataTrace(BaseTraceInstance):
 
         user_id = message_data.from_account_id
         if message_data.from_end_user_id:
-            end_user_data: EndUser = db.session.query(EndUser).filter(
-                EndUser.id == message_data.from_end_user_id
-            ).first()
+            end_user_data: EndUser = (
+                db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
+            )
             if end_user_data is not None:
                 user_id = end_user_data.session_id
                 metadata["user_id"] = user_id
@@ -235,7 +236,7 @@ class LangFuseDataTrace(BaseTraceInstance):
         trace_data = LangfuseTrace(
             id=message_id,
             user_id=user_id,
-            name=f"message_{message_id}",
+            name=TraceTaskName.MESSAGE_TRACE.value,
             input={
                 "message": trace_info.inputs,
                 "files": file_list,
@@ -258,7 +259,6 @@ class LangFuseDataTrace(BaseTraceInstance):
 
         # start add span
         generation_usage = GenerationUsage(
-            totalTokens=trace_info.total_tokens,
             input=trace_info.message_tokens,
             output=trace_info.answer_tokens,
             total=trace_info.total_tokens,
@@ -267,7 +267,7 @@ class LangFuseDataTrace(BaseTraceInstance):
         )
 
         langfuse_generation_data = LangfuseGeneration(
-            name=f"generation_{message_id}",
+            name="llm",
             trace_id=message_id,
             start_time=trace_info.start_time,
             end_time=trace_info.end_time,
@@ -275,7 +275,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             input=trace_info.inputs,
             output=message_data.answer,
             metadata=metadata,
-            level=LevelEnum.DEFAULT if message_data.status != 'error' else LevelEnum.ERROR,
+            level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR),
             status_message=message_data.error if message_data.error else "",
             usage=generation_usage,
         )
@@ -284,7 +284,7 @@ class LangFuseDataTrace(BaseTraceInstance):
 
     def moderation_trace(self, trace_info: ModerationTraceInfo):
         span_data = LangfuseSpan(
-            name="moderation",
+            name=TraceTaskName.MODERATION_TRACE.value,
             input=trace_info.inputs,
             output={
                 "action": trace_info.action,
@@ -303,22 +303,21 @@ class LangFuseDataTrace(BaseTraceInstance):
     def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
         message_data = trace_info.message_data
         generation_usage = GenerationUsage(
-            totalTokens=len(str(trace_info.suggested_question)),
+            total=len(str(trace_info.suggested_question)),
             input=len(trace_info.inputs),
             output=len(trace_info.suggested_question),
-            total=len(trace_info.suggested_question),
             unit=UnitEnum.CHARACTERS,
         )
 
         generation_data = LangfuseGeneration(
-            name="suggested_question",
+            name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
             input=trace_info.inputs,
             output=str(trace_info.suggested_question),
             trace_id=trace_info.message_id,
             start_time=trace_info.start_time,
             end_time=trace_info.end_time,
             metadata=trace_info.metadata,
-            level=LevelEnum.DEFAULT if message_data.status != 'error' else LevelEnum.ERROR,
+            level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR),
             status_message=message_data.error if message_data.error else "",
             usage=generation_usage,
         )
@@ -327,7 +326,7 @@ class LangFuseDataTrace(BaseTraceInstance):
 
     def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
         dataset_retrieval_span_data = LangfuseSpan(
-            name="dataset_retrieval",
+            name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
             input=trace_info.inputs,
             output={"documents": trace_info.documents},
             trace_id=trace_info.message_id,
@@ -347,7 +346,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             start_time=trace_info.start_time,
             end_time=trace_info.end_time,
             metadata=trace_info.metadata,
-            level=LevelEnum.DEFAULT if trace_info.error == "" or trace_info.error is None else LevelEnum.ERROR,
+            level=(LevelEnum.DEFAULT if trace_info.error == "" or trace_info.error is None else LevelEnum.ERROR),
             status_message=trace_info.error,
         )
 
@@ -355,7 +354,7 @@ class LangFuseDataTrace(BaseTraceInstance):
 
     def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
         name_generation_trace_data = LangfuseTrace(
-            name="generate_name",
+            name=TraceTaskName.GENERATE_NAME_TRACE.value,
             input=trace_info.inputs,
             output=trace_info.outputs,
             user_id=trace_info.tenant_id,
@@ -366,7 +365,7 @@ class LangFuseDataTrace(BaseTraceInstance):
         self.add_trace(langfuse_trace_data=name_generation_trace_data)
 
         name_generation_span_data = LangfuseSpan(
-            name="generate_name",
+            name=TraceTaskName.GENERATE_NAME_TRACE.value,
             input=trace_info.inputs,
             output=trace_info.outputs,
             trace_id=trace_info.conversation_id,
@@ -377,9 +376,7 @@ class LangFuseDataTrace(BaseTraceInstance):
         self.add_span(langfuse_span_data=name_generation_span_data)
 
     def add_trace(self, langfuse_trace_data: Optional[LangfuseTrace] = None):
-        format_trace_data = (
-            filter_none_values(langfuse_trace_data.model_dump()) if langfuse_trace_data else {}
-        )
+        format_trace_data = filter_none_values(langfuse_trace_data.model_dump()) if langfuse_trace_data else {}
         try:
             self.langfuse_client.trace(**format_trace_data)
             logger.debug("LangFuse Trace created successfully")
@@ -387,9 +384,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             raise ValueError(f"LangFuse Failed to create trace: {str(e)}")
 
     def add_span(self, langfuse_span_data: Optional[LangfuseSpan] = None):
-        format_span_data = (
-            filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {}
-        )
+        format_span_data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {}
         try:
             self.langfuse_client.span(**format_span_data)
             logger.debug("LangFuse Span created successfully")
@@ -397,19 +392,13 @@ class LangFuseDataTrace(BaseTraceInstance):
             raise ValueError(f"LangFuse Failed to create span: {str(e)}")
 
     def update_span(self, span, langfuse_span_data: Optional[LangfuseSpan] = None):
-        format_span_data = (
-            filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {}
-        )
+        format_span_data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {}
 
         span.end(**format_span_data)
 
-    def add_generation(
-        self, langfuse_generation_data: Optional[LangfuseGeneration] = None
-    ):
+    def add_generation(self, langfuse_generation_data: Optional[LangfuseGeneration] = None):
         format_generation_data = (
-            filter_none_values(langfuse_generation_data.model_dump())
-            if langfuse_generation_data
-            else {}
+            filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {}
         )
         try:
             self.langfuse_client.generation(**format_generation_data)
@@ -417,13 +406,9 @@ class LangFuseDataTrace(BaseTraceInstance):
         except Exception as e:
             raise ValueError(f"LangFuse Failed to create generation: {str(e)}")
 
-    def update_generation(
-        self, generation, langfuse_generation_data: Optional[LangfuseGeneration] = None
-    ):
+    def update_generation(self, generation, langfuse_generation_data: Optional[LangfuseGeneration] = None):
         format_generation_data = (
-            filter_none_values(langfuse_generation_data.model_dump())
-            if langfuse_generation_data
-            else {}
+            filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {}
         )
 
         generation.end(**format_generation_data)

+ 24 - 27
api/core/ops/langsmith_trace/langsmith_trace.py

@@ -15,6 +15,7 @@ from core.ops.entities.trace_entity import (
     ModerationTraceInfo,
     SuggestedQuestionTraceInfo,
     ToolTraceInfo,
+    TraceTaskName,
     WorkflowTraceInfo,
 )
 from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
@@ -39,9 +40,7 @@ class LangSmithDataTrace(BaseTraceInstance):
         self.langsmith_key = langsmith_config.api_key
         self.project_name = langsmith_config.project
         self.project_id = None
-        self.langsmith_client = Client(
-            api_key=langsmith_config.api_key, api_url=langsmith_config.endpoint
-        )
+        self.langsmith_client = Client(api_key=langsmith_config.api_key, api_url=langsmith_config.endpoint)
         self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
 
     def trace(self, trace_info: BaseTraceInfo):
@@ -64,7 +63,7 @@ class LangSmithDataTrace(BaseTraceInstance):
         if trace_info.message_id:
             message_run = LangSmithRunModel(
                 id=trace_info.message_id,
-                name=f"message_{trace_info.message_id}",
+                name=TraceTaskName.MESSAGE_TRACE.value,
                 inputs=trace_info.workflow_run_inputs,
                 outputs=trace_info.workflow_run_outputs,
                 run_type=LangSmithRunType.chain,
@@ -73,8 +72,8 @@ class LangSmithDataTrace(BaseTraceInstance):
                 extra={
                     "metadata": trace_info.metadata,
                 },
-                tags=["message"],
-                error=trace_info.error
+                tags=["message", "workflow"],
+                error=trace_info.error,
             )
             self.add_run(message_run)
 
@@ -82,7 +81,7 @@ class LangSmithDataTrace(BaseTraceInstance):
             file_list=trace_info.file_list,
             total_tokens=trace_info.total_tokens,
             id=trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id,
-            name=f"workflow_{trace_info.workflow_app_log_id}" if trace_info.workflow_app_log_id else f"workflow_{trace_info.workflow_run_id}",
+            name=TraceTaskName.WORKFLOW_TRACE.value,
             inputs=trace_info.workflow_run_inputs,
             run_type=LangSmithRunType.tool,
             start_time=trace_info.workflow_data.created_at,
@@ -126,22 +125,18 @@ class LangSmithDataTrace(BaseTraceInstance):
             node_type = node_execution.node_type
             status = node_execution.status
             if node_type == "llm":
-                inputs = json.loads(node_execution.process_data).get(
-                    "prompts", {}
-                    ) if node_execution.process_data else {}
+                inputs = (
+                    json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
+                )
             else:
                 inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
-            outputs = (
-                json.loads(node_execution.outputs) if node_execution.outputs else {}
-            )
+            outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
             created_at = node_execution.created_at if node_execution.created_at else datetime.now()
             elapsed_time = node_execution.elapsed_time
             finished_at = created_at + timedelta(seconds=elapsed_time)
 
             execution_metadata = (
-                json.loads(node_execution.execution_metadata)
-                if node_execution.execution_metadata
-                else {}
+                json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
             )
             node_total_tokens = execution_metadata.get("total_tokens", 0)
 
@@ -168,7 +163,7 @@ class LangSmithDataTrace(BaseTraceInstance):
 
             langsmith_run = LangSmithRunModel(
                 total_tokens=node_total_tokens,
-                name=f"{node_name}_{node_execution_id}",
+                name=node_type,
                 inputs=inputs,
                 run_type=run_type,
                 start_time=created_at,
@@ -178,7 +173,9 @@ class LangSmithDataTrace(BaseTraceInstance):
                 extra={
                     "metadata": metadata,
                 },
-                parent_run_id=trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id,
+                parent_run_id=trace_info.workflow_app_log_id
+                if trace_info.workflow_app_log_id
+                else trace_info.workflow_run_id,
                 tags=["node_execution"],
             )
 
@@ -198,9 +195,9 @@ class LangSmithDataTrace(BaseTraceInstance):
         metadata["user_id"] = user_id
 
         if message_data.from_end_user_id:
-            end_user_data: EndUser = db.session.query(EndUser).filter(
-                EndUser.id == message_data.from_end_user_id
-            ).first()
+            end_user_data: EndUser = (
+                db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
+            )
             if end_user_data is not None:
                 end_user_id = end_user_data.session_id
                 metadata["end_user_id"] = end_user_id
@@ -210,7 +207,7 @@ class LangSmithDataTrace(BaseTraceInstance):
             output_tokens=trace_info.answer_tokens,
             total_tokens=trace_info.total_tokens,
             id=message_id,
-            name=f"message_{message_id}",
+            name=TraceTaskName.MESSAGE_TRACE.value,
             inputs=trace_info.inputs,
             run_type=LangSmithRunType.chain,
             start_time=trace_info.start_time,
@@ -230,7 +227,7 @@ class LangSmithDataTrace(BaseTraceInstance):
             input_tokens=trace_info.message_tokens,
             output_tokens=trace_info.answer_tokens,
             total_tokens=trace_info.total_tokens,
-            name=f"llm_{message_id}",
+            name="llm",
             inputs=trace_info.inputs,
             run_type=LangSmithRunType.llm,
             start_time=trace_info.start_time,
@@ -248,7 +245,7 @@ class LangSmithDataTrace(BaseTraceInstance):
 
     def moderation_trace(self, trace_info: ModerationTraceInfo):
         langsmith_run = LangSmithRunModel(
-            name="moderation",
+            name=TraceTaskName.MODERATION_TRACE.value,
             inputs=trace_info.inputs,
             outputs={
                 "action": trace_info.action,
@@ -271,7 +268,7 @@ class LangSmithDataTrace(BaseTraceInstance):
     def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
         message_data = trace_info.message_data
         suggested_question_run = LangSmithRunModel(
-            name="suggested_question",
+            name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
             inputs=trace_info.inputs,
             outputs=trace_info.suggested_question,
             run_type=LangSmithRunType.tool,
@@ -288,7 +285,7 @@ class LangSmithDataTrace(BaseTraceInstance):
 
     def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
         dataset_retrieval_run = LangSmithRunModel(
-            name="dataset_retrieval",
+            name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
             inputs=trace_info.inputs,
             outputs={"documents": trace_info.documents},
             run_type=LangSmithRunType.retriever,
@@ -323,7 +320,7 @@ class LangSmithDataTrace(BaseTraceInstance):
 
     def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
         name_run = LangSmithRunModel(
-            name="generate_name",
+            name=TraceTaskName.GENERATE_NAME_TRACE.value,
             inputs=trace_info.inputs,
             outputs=trace_info.outputs,
             run_type=LangSmithRunType.tool,

+ 1 - 12
api/core/ops/ops_trace_manager.py

@@ -5,7 +5,6 @@ import queue
 import threading
 import time
 from datetime import timedelta
-from enum import Enum
 from typing import Any, Optional, Union
 from uuid import UUID
 
@@ -24,6 +23,7 @@ from core.ops.entities.trace_entity import (
     ModerationTraceInfo,
     SuggestedQuestionTraceInfo,
     ToolTraceInfo,
+    TraceTaskName,
     WorkflowTraceInfo,
 )
 from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
@@ -253,17 +253,6 @@ class OpsTraceManager:
         return trace_instance(tracing_config).api_check()
 
 
-class TraceTaskName(str, Enum):
-    CONVERSATION_TRACE = 'conversation_trace'
-    WORKFLOW_TRACE = 'workflow_trace'
-    MESSAGE_TRACE = 'message_trace'
-    MODERATION_TRACE = 'moderation_trace'
-    SUGGESTED_QUESTION_TRACE = 'suggested_question_trace'
-    DATASET_RETRIEVAL_TRACE = 'dataset_retrieval_trace'
-    TOOL_TRACE = 'tool_trace'
-    GENERATE_NAME_TRACE = 'generate_name_trace'
-
-
 class TraceTask:
     def __init__(
         self,

+ 2 - 1
api/core/rag/retrieval/dataset_retrieval.py

@@ -14,7 +14,8 @@ from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.message_entities import PromptMessageTool
 from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
+from core.ops.entities.trace_entity import TraceTaskName
+from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
 from core.ops.utils import measure_time
 from core.rag.data_post_processor.data_post_processor import DataPostProcessor
 from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler

+ 2 - 1
api/services/message_service.py

@@ -7,7 +7,8 @@ from core.llm_generator.llm_generator import LLMGenerator
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
-from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
+from core.ops.entities.trace_entity import TraceTaskName
+from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
 from core.ops.utils import measure_time
 from extensions.ext_database import db
 from libs.infinite_scroll_pagination import InfiniteScrollPagination