Sfoglia il codice sorgente

feat: implement function dispatch table for trace processing (#6628)

Joe 8 mesi fa
parent
commit
f78d0082ae
1 ha cambiato i file con 18 aggiunte e 23 eliminazioni
  1. 18 23
      api/core/ops/ops_trace_manager.py

+ 18 - 23
api/core/ops/ops_trace_manager.py

@@ -298,34 +298,29 @@ class TraceTask:
         self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
 
     def execute(self):
-        method_name, trace_info = self.preprocess()
-        return trace_info
+        return self.preprocess()
 
     def preprocess(self):
-        if self.trace_type == TraceTaskName.CONVERSATION_TRACE:
-            return TraceTaskName.CONVERSATION_TRACE, self.conversation_trace(**self.kwargs)
-        if self.trace_type == TraceTaskName.WORKFLOW_TRACE:
-            return TraceTaskName.WORKFLOW_TRACE, self.workflow_trace(self.workflow_run, self.conversation_id)
-        elif self.trace_type == TraceTaskName.MESSAGE_TRACE:
-            return TraceTaskName.MESSAGE_TRACE, self.message_trace(self.message_id)
-        elif self.trace_type == TraceTaskName.MODERATION_TRACE:
-            return TraceTaskName.MODERATION_TRACE, self.moderation_trace(self.message_id, self.timer, **self.kwargs)
-        elif self.trace_type == TraceTaskName.SUGGESTED_QUESTION_TRACE:
-            return TraceTaskName.SUGGESTED_QUESTION_TRACE, self.suggested_question_trace(
+        preprocess_map = {
+            TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs),
+            TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace(self.workflow_run, self.conversation_id),
+            TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(self.message_id),
+            TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(
                 self.message_id, self.timer, **self.kwargs
-            )
-        elif self.trace_type == TraceTaskName.DATASET_RETRIEVAL_TRACE:
-            return TraceTaskName.DATASET_RETRIEVAL_TRACE, self.dataset_retrieval_trace(
+            ),
+            TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace(
                 self.message_id, self.timer, **self.kwargs
-            )
-        elif self.trace_type == TraceTaskName.TOOL_TRACE:
-            return TraceTaskName.TOOL_TRACE, self.tool_trace(self.message_id, self.timer, **self.kwargs)
-        elif self.trace_type == TraceTaskName.GENERATE_NAME_TRACE:
-            return TraceTaskName.GENERATE_NAME_TRACE, self.generate_name_trace(
+            ),
+            TraceTaskName.DATASET_RETRIEVAL_TRACE: lambda: self.dataset_retrieval_trace(
+                self.message_id, self.timer, **self.kwargs
+            ),
+            TraceTaskName.TOOL_TRACE: lambda: self.tool_trace(self.message_id, self.timer, **self.kwargs),
+            TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace(
                 self.conversation_id, self.timer, **self.kwargs
-            )
-        else:
-            return '', {}
+            ),
+        }
+
+        return preprocess_map.get(self.trace_type, lambda: None)()
 
     # process methods for different trace types
     def conversation_trace(self, **kwargs):