|
@@ -1,4 +1,6 @@
|
|
|
+import json
|
|
|
from collections.abc import Generator
|
|
|
+from copy import deepcopy
|
|
|
from typing import Optional, cast
|
|
|
|
|
|
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
|
@@ -17,11 +19,15 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
|
|
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
|
|
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
|
|
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
|
|
-from core.workflow.entities.base_node_data_entities import BaseNodeData
|
|
|
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
|
|
|
from core.workflow.entities.variable_pool import VariablePool
|
|
|
from core.workflow.nodes.base_node import BaseNode
|
|
|
-from core.workflow.nodes.llm.entities import LLMNodeData, ModelConfig
|
|
|
+from core.workflow.nodes.llm.entities import (
|
|
|
+ LLMNodeChatModelMessage,
|
|
|
+ LLMNodeCompletionModelPromptTemplate,
|
|
|
+ LLMNodeData,
|
|
|
+ ModelConfig,
|
|
|
+)
|
|
|
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
|
|
from extensions.ext_database import db
|
|
|
from models.model import Conversation
|
|
@@ -39,16 +45,24 @@ class LLMNode(BaseNode):
|
|
|
:param variable_pool: variable pool
|
|
|
:return:
|
|
|
"""
|
|
|
- node_data = self.node_data
|
|
|
- node_data = cast(self._node_data_cls, node_data)
|
|
|
+ node_data = cast(LLMNodeData, deepcopy(self.node_data))
|
|
|
|
|
|
node_inputs = None
|
|
|
process_data = None
|
|
|
|
|
|
try:
|
|
|
+ # init messages template
|
|
|
+ node_data.prompt_template = self._transform_chat_messages(node_data.prompt_template)
|
|
|
+
|
|
|
# fetch variables and fetch values from variable pool
|
|
|
inputs = self._fetch_inputs(node_data, variable_pool)
|
|
|
|
|
|
+ # fetch jinja2 inputs
|
|
|
+ jinja_inputs = self._fetch_jinja_inputs(node_data, variable_pool)
|
|
|
+
|
|
|
+ # merge inputs
|
|
|
+ inputs.update(jinja_inputs)
|
|
|
+
|
|
|
node_inputs = {}
|
|
|
|
|
|
# fetch files
|
|
@@ -183,6 +197,86 @@ class LLMNode(BaseNode):
|
|
|
usage = LLMUsage.empty_usage()
|
|
|
|
|
|
return full_text, usage
|
|
|
+
|
|
|
+ def _transform_chat_messages(self,
|
|
|
+ messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
|
|
|
+ ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
|
|
|
+ """
|
|
|
+ Transform chat messages
|
|
|
+
|
|
|
+ :param messages: chat messages
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+
|
|
|
+ if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
|
|
|
+ if messages.edition_type == 'jinja2':
|
|
|
+ messages.text = messages.jinja2_text
|
|
|
+
|
|
|
+ return messages
|
|
|
+
|
|
|
+ for message in messages:
|
|
|
+ if message.edition_type == 'jinja2':
|
|
|
+ message.text = message.jinja2_text
|
|
|
+
|
|
|
+ return messages
|
|
|
+
|
|
|
+ def _fetch_jinja_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
|
|
|
+ """
|
|
|
+ Fetch jinja inputs
|
|
|
+ :param node_data: node data
|
|
|
+ :param variable_pool: variable pool
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ variables = {}
|
|
|
+
|
|
|
+ if not node_data.prompt_config:
|
|
|
+ return variables
|
|
|
+
|
|
|
+ for variable_selector in node_data.prompt_config.jinja2_variables or []:
|
|
|
+ variable = variable_selector.variable
|
|
|
+ value = variable_pool.get_variable_value(
|
|
|
+ variable_selector=variable_selector.value_selector
|
|
|
+ )
|
|
|
+
|
|
|
+ def parse_dict(d: dict) -> str:
|
|
|
+ """
|
|
|
+ Parse dict into string
|
|
|
+ """
|
|
|
+ # check if it's a context structure
|
|
|
+ if 'metadata' in d and '_source' in d['metadata'] and 'content' in d:
|
|
|
+ return d['content']
|
|
|
+
|
|
|
+ # else, parse the dict
|
|
|
+ try:
|
|
|
+ return json.dumps(d, ensure_ascii=False)
|
|
|
+ except Exception:
|
|
|
+ return str(d)
|
|
|
+
|
|
|
+ if isinstance(value, str):
|
|
|
+ value = value
|
|
|
+ elif isinstance(value, list):
|
|
|
+ result = ''
|
|
|
+ for item in value:
|
|
|
+ if isinstance(item, dict):
|
|
|
+ result += parse_dict(item)
|
|
|
+ elif isinstance(item, str):
|
|
|
+ result += item
|
|
|
+ elif isinstance(item, int | float):
|
|
|
+ result += str(item)
|
|
|
+ else:
|
|
|
+ result += str(item)
|
|
|
+ result += '\n'
|
|
|
+ value = result.strip()
|
|
|
+ elif isinstance(value, dict):
|
|
|
+ value = parse_dict(value)
|
|
|
+ elif isinstance(value, int | float):
|
|
|
+ value = str(value)
|
|
|
+ else:
|
|
|
+ value = str(value)
|
|
|
+
|
|
|
+ variables[variable] = value
|
|
|
+
|
|
|
+ return variables
|
|
|
|
|
|
def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
|
|
|
"""
|
|
@@ -531,25 +625,25 @@ class LLMNode(BaseNode):
|
|
|
db.session.commit()
|
|
|
|
|
|
@classmethod
|
|
|
- def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
|
|
+ def _extract_variable_selector_to_variable_mapping(cls, node_data: LLMNodeData) -> dict[str, list[str]]:
|
|
|
"""
|
|
|
Extract variable selector to variable mapping
|
|
|
:param node_data: node data
|
|
|
:return:
|
|
|
"""
|
|
|
- node_data = node_data
|
|
|
- node_data = cast(cls._node_data_cls, node_data)
|
|
|
|
|
|
prompt_template = node_data.prompt_template
|
|
|
|
|
|
variable_selectors = []
|
|
|
if isinstance(prompt_template, list):
|
|
|
for prompt in prompt_template:
|
|
|
- variable_template_parser = VariableTemplateParser(template=prompt.text)
|
|
|
- variable_selectors.extend(variable_template_parser.extract_variable_selectors())
|
|
|
+ if prompt.edition_type != 'jinja2':
|
|
|
+ variable_template_parser = VariableTemplateParser(template=prompt.text)
|
|
|
+ variable_selectors.extend(variable_template_parser.extract_variable_selectors())
|
|
|
else:
|
|
|
- variable_template_parser = VariableTemplateParser(template=prompt_template.text)
|
|
|
- variable_selectors = variable_template_parser.extract_variable_selectors()
|
|
|
+ if prompt_template.edition_type != 'jinja2':
|
|
|
+ variable_template_parser = VariableTemplateParser(template=prompt_template.text)
|
|
|
+ variable_selectors = variable_template_parser.extract_variable_selectors()
|
|
|
|
|
|
variable_mapping = {}
|
|
|
for variable_selector in variable_selectors:
|
|
@@ -571,6 +665,22 @@ class LLMNode(BaseNode):
|
|
|
if node_data.memory:
|
|
|
variable_mapping['#sys.query#'] = ['sys', SystemVariable.QUERY.value]
|
|
|
|
|
|
+ if node_data.prompt_config:
|
|
|
+ enable_jinja = False
|
|
|
+
|
|
|
+ if isinstance(prompt_template, list):
|
|
|
+ for prompt in prompt_template:
|
|
|
+ if prompt.edition_type == 'jinja2':
|
|
|
+ enable_jinja = True
|
|
|
+ break
|
|
|
+ else:
|
|
|
+ if prompt_template.edition_type == 'jinja2':
|
|
|
+ enable_jinja = True
|
|
|
+
|
|
|
+ if enable_jinja:
|
|
|
+ for variable_selector in node_data.prompt_config.jinja2_variables or []:
|
|
|
+ variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
|
|
+
|
|
|
return variable_mapping
|
|
|
|
|
|
@classmethod
|
|
@@ -588,7 +698,8 @@ class LLMNode(BaseNode):
|
|
|
"prompts": [
|
|
|
{
|
|
|
"role": "system",
|
|
|
- "text": "You are a helpful AI assistant."
|
|
|
+ "text": "You are a helpful AI assistant.",
|
|
|
+ "edition_type": "basic"
|
|
|
}
|
|
|
]
|
|
|
},
|
|
@@ -600,7 +711,8 @@ class LLMNode(BaseNode):
|
|
|
"prompt": {
|
|
|
"text": "Here is the chat histories between human and assistant, inside "
|
|
|
"<histories></histories> XML tags.\n\n<histories>\n{{"
|
|
|
- "#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:"
|
|
|
+ "#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:",
|
|
|
+ "edition_type": "basic"
|
|
|
},
|
|
|
"stop": ["Human:"]
|
|
|
}
|