|
@@ -74,6 +74,7 @@ class LLMNode(BaseNode):
|
|
|
node_data=node_data,
|
|
|
query=variable_pool.get_variable_value(['sys', SystemVariable.QUERY.value])
|
|
|
if node_data.memory else None,
|
|
|
+ query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
|
|
|
inputs=inputs,
|
|
|
files=files,
|
|
|
context=context,
|
|
@@ -209,6 +210,17 @@ class LLMNode(BaseNode):
|
|
|
|
|
|
inputs[variable_selector.variable] = variable_value
|
|
|
|
|
|
+ memory = node_data.memory
|
|
|
+ if memory and memory.query_prompt_template:
|
|
|
+ query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template)
|
|
|
+ .extract_variable_selectors())
|
|
|
+ for variable_selector in query_variable_selectors:
|
|
|
+ variable_value = variable_pool.get_variable_value(variable_selector.value_selector)
|
|
|
+ if variable_value is None:
|
|
|
+ raise ValueError(f'Variable {variable_selector.variable} not found')
|
|
|
+
|
|
|
+ inputs[variable_selector.variable] = variable_value
|
|
|
+
|
|
|
return inputs
|
|
|
|
|
|
def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileVar]:
|
|
@@ -302,7 +314,8 @@ class LLMNode(BaseNode):
|
|
|
|
|
|
return None
|
|
|
|
|
|
- def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
|
|
+ def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[
|
|
|
+ ModelInstance, ModelConfigWithCredentialsEntity]:
|
|
|
"""
|
|
|
Fetch model config
|
|
|
:param node_data_model: node data model
|
|
@@ -407,6 +420,7 @@ class LLMNode(BaseNode):
|
|
|
|
|
|
def _fetch_prompt_messages(self, node_data: LLMNodeData,
|
|
|
query: Optional[str],
|
|
|
+ query_prompt_template: Optional[str],
|
|
|
inputs: dict[str, str],
|
|
|
files: list[FileVar],
|
|
|
context: Optional[str],
|
|
@@ -417,6 +431,7 @@ class LLMNode(BaseNode):
|
|
|
Fetch prompt messages
|
|
|
:param node_data: node data
|
|
|
:param query: query
|
|
|
+ :param query_prompt_template: query prompt template
|
|
|
:param inputs: inputs
|
|
|
:param files: files
|
|
|
:param context: context
|
|
@@ -433,7 +448,8 @@ class LLMNode(BaseNode):
|
|
|
context=context,
|
|
|
memory_config=node_data.memory,
|
|
|
memory=memory,
|
|
|
- model_config=model_config
|
|
|
+ model_config=model_config,
|
|
|
+ query_prompt_template=query_prompt_template,
|
|
|
)
|
|
|
stop = model_config.stop
|
|
|
|
|
@@ -539,6 +555,13 @@ class LLMNode(BaseNode):
|
|
|
for variable_selector in variable_selectors:
|
|
|
variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
|
|
|
|
|
+ memory = node_data.memory
|
|
|
+ if memory and memory.query_prompt_template:
|
|
|
+ query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template)
|
|
|
+ .extract_variable_selectors())
|
|
|
+ for variable_selector in query_variable_selectors:
|
|
|
+ variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
|
|
+
|
|
|
if node_data.context.enabled:
|
|
|
variable_mapping['#context#'] = node_data.context.variable_selector
|
|
|
|