Bladeren bron

support instruction in classifier node (#4913)

Jyong 10 maanden geleden
bovenliggende
commit
c7bddb637b
1 gewijzigd bestand met toevoegingen van 34 en 1 verwijderingen
  1. 34 1
      api/core/workflow/nodes/question_classifier/question_classifier_node.py

+ 34 - 1
api/core/workflow/nodes/question_classifier/question_classifier_node.py

@@ -12,6 +12,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
 from core.prompt.simple_prompt_transform import ModelMode
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
+from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
@@ -26,6 +27,7 @@ from core.workflow.nodes.question_classifier.template_prompts import (
     QUESTION_CLASSIFIER_USER_PROMPT_2,
     QUESTION_CLASSIFIER_USER_PROMPT_3,
 )
+from core.workflow.utils.variable_template_parser import VariableTemplateParser
 from libs.json_in_md_parser import parse_and_check_json_markdown
 from models.workflow import WorkflowNodeExecutionStatus
 
@@ -47,6 +49,9 @@ class QuestionClassifierNode(LLMNode):
         model_instance, model_config = self._fetch_model_config(node_data.model)
         # fetch memory
         memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
+        # fetch instruction
+        instruction = self._format_instruction(node_data.instruction, variable_pool)
+        node_data.instruction = instruction
         # fetch prompt messages
         prompt_messages, stop = self._fetch_prompt(
             node_data=node_data,
@@ -122,6 +127,12 @@ class QuestionClassifierNode(LLMNode):
         node_data = node_data
         node_data = cast(cls._node_data_cls, node_data)
         variable_mapping = {'query': node_data.query_variable_selector}
+        variable_selectors = []
+        if node_data.instruction:
+            variable_template_parser = VariableTemplateParser(template=node_data.instruction)
+            variable_selectors.extend(variable_template_parser.extract_variable_selectors())
+        for variable_selector in variable_selectors:
+            variable_mapping[variable_selector.variable] = variable_selector.value_selector
         return variable_mapping
 
     @classmethod
@@ -269,8 +280,30 @@ class QuestionClassifierNode(LLMNode):
                 text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(histories=memory_str,
                                                                   input_text=input_text,
                                                                   categories=json.dumps(categories),
-                                                                  classification_instructions=instruction, ensure_ascii=False)
+                                                                  classification_instructions=instruction,
+                                                                  ensure_ascii=False)
             )
 
         else:
             raise ValueError(f"Model mode {model_mode} not support.")
+
+    def _format_instruction(self, instruction: str, variable_pool: VariablePool) -> str:
+        inputs = {}
+
+        variable_selectors = []
+        variable_template_parser = VariableTemplateParser(template=instruction)
+        variable_selectors.extend(variable_template_parser.extract_variable_selectors())
+        for variable_selector in variable_selectors:
+            variable_value = variable_pool.get_variable_value(variable_selector.value_selector)
+            if variable_value is None:
+                raise ValueError(f'Variable {variable_selector.variable} not found')
+
+            inputs[variable_selector.variable] = variable_value
+
+        prompt_template = PromptTemplateParser(template=instruction, with_variable_tmpl=True)
+        prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+
+        instruction = prompt_template.format(
+            prompt_inputs
+        )
+        return instruction