Ver Fonte

Fix/refactor invoke result handling in question classifier node (#12015)

Signed-off-by: -LAN- <laipz8200@outlook.com>
-LAN- há 4 meses atrás
pai
commit
c3c85276d1

+ 18 - 23
api/core/workflow/nodes/question_classifier/question_classifier_node.py

@@ -1,10 +1,8 @@
 import json
-import logging
 from collections.abc import Mapping, Sequence
 from typing import TYPE_CHECKING, Any, Optional, cast
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
-from core.llm_generator.output_parser.errors import OutputParserError
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
@@ -96,27 +94,28 @@ class QuestionClassifierNode(LLMNode):
             jinja2_variables=[],
         )
 
-        # handle invoke result
-        generator = self._invoke_llm(
-            node_data_model=node_data.model,
-            model_instance=model_instance,
-            prompt_messages=prompt_messages,
-            stop=stop,
-        )
-
         result_text = ""
         usage = LLMUsage.empty_usage()
         finish_reason = None
-        for event in generator:
-            if isinstance(event, ModelInvokeCompletedEvent):
-                result_text = event.text
-                usage = event.usage
-                finish_reason = event.finish_reason
-                break
 
-        category_name = node_data.classes[0].name
-        category_id = node_data.classes[0].id
         try:
+            # handle invoke result
+            generator = self._invoke_llm(
+                node_data_model=node_data.model,
+                model_instance=model_instance,
+                prompt_messages=prompt_messages,
+                stop=stop,
+            )
+
+            for event in generator:
+                if isinstance(event, ModelInvokeCompletedEvent):
+                    result_text = event.text
+                    usage = event.usage
+                    finish_reason = event.finish_reason
+                    break
+
+            category_name = node_data.classes[0].name
+            category_id = node_data.classes[0].id
             result_text_json = parse_and_check_json_markdown(result_text, [])
             # result_text_json = json.loads(result_text.strip('```JSON\n'))
             if "category_name" in result_text_json and "category_id" in result_text_json:
@@ -127,10 +126,6 @@ class QuestionClassifierNode(LLMNode):
                 if category_id_result in category_ids:
                     category_name = classes_map[category_id_result]
                     category_id = category_id_result
-
-        except OutputParserError:
-            logging.exception(f"Failed to parse result text: {result_text}")
-        try:
             process_data = {
                 "model_mode": model_config.mode,
                 "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
@@ -154,7 +149,7 @@ class QuestionClassifierNode(LLMNode):
                 },
                 llm_usage=usage,
             )
-        except Exception as e:
+        except ValueError as e:
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.FAILED,
                 inputs=variables,