Jelajahi Sumber

feat: support LLM jinja2 template prompt (#3968)

Co-authored-by: Joel <iamjoel007@gmail.com>
Yeuoly 11 bulan lalu
induk
melakukan
8578ee0864

+ 17 - 0
api/core/helper/code_executor/jinja2_formatter.py

@@ -0,0 +1,17 @@
+from core.helper.code_executor.code_executor import CodeExecutor
+
+
+class Jinja2Formatter:
+    @classmethod
+    def format(cls, template: str, inputs: str) -> str:
+        """
+        Format template
+        :param template: template
+        :param inputs: inputs
+        :return:
+        """
+        result = CodeExecutor.execute_workflow_code_template(
+            language='jinja2', code=template, inputs=inputs
+        )
+
+        return result['result']

+ 40 - 25
api/core/prompt/advanced_prompt_transform.py

@@ -2,6 +2,7 @@ from typing import Optional, Union
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.file.file_obj import FileVar
+from core.helper.code_executor.jinja2_formatter import Jinja2Formatter
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
@@ -80,29 +81,35 @@ class AdvancedPromptTransform(PromptTransform):
 
         prompt_messages = []
 
-        prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
-        prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+        if prompt_template.edition_type == 'basic' or not prompt_template.edition_type:
+            prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
+            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
 
-        prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
+            prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
 
-        if memory and memory_config:
-            role_prefix = memory_config.role_prefix
-            prompt_inputs = self._set_histories_variable(
-                memory=memory,
-                memory_config=memory_config,
-                raw_prompt=raw_prompt,
-                role_prefix=role_prefix,
-                prompt_template=prompt_template,
-                prompt_inputs=prompt_inputs,
-                model_config=model_config
-            )
+            if memory and memory_config:
+                role_prefix = memory_config.role_prefix
+                prompt_inputs = self._set_histories_variable(
+                    memory=memory,
+                    memory_config=memory_config,
+                    raw_prompt=raw_prompt,
+                    role_prefix=role_prefix,
+                    prompt_template=prompt_template,
+                    prompt_inputs=prompt_inputs,
+                    model_config=model_config
+                )
 
-        if query:
-            prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs)
+            if query:
+                prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs)
 
-        prompt = prompt_template.format(
-            prompt_inputs
-        )
+            prompt = prompt_template.format(
+                prompt_inputs
+            )
+        else:
+            prompt = raw_prompt
+            prompt_inputs = inputs
+
+            prompt = Jinja2Formatter.format(prompt, prompt_inputs)
 
         if files:
             prompt_message_contents = [TextPromptMessageContent(data=prompt)]
@@ -135,14 +142,22 @@ class AdvancedPromptTransform(PromptTransform):
         for prompt_item in raw_prompt_list:
             raw_prompt = prompt_item.text
 
-            prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
-            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+            if prompt_item.edition_type == 'basic' or not prompt_item.edition_type:
+                prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
+                prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
 
-            prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
+                prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
 
-            prompt = prompt_template.format(
-                prompt_inputs
-            )
+                prompt = prompt_template.format(
+                    prompt_inputs
+                )
+            elif prompt_item.edition_type == 'jinja2':
+                prompt = raw_prompt
+                prompt_inputs = inputs
+
+                prompt = Jinja2Formatter.format(prompt, prompt_inputs)
+            else:
+                raise ValueError(f'Invalid edition type: {prompt_item.edition_type}')
 
             if prompt_item.role == PromptMessageRole.USER:
                 prompt_messages.append(UserPromptMessage(content=prompt))

+ 3 - 1
api/core/prompt/entities/advanced_prompt_entities.py

@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Literal, Optional
 
 from pydantic import BaseModel
 
@@ -11,6 +11,7 @@ class ChatModelMessage(BaseModel):
     """
     text: str
     role: PromptMessageRole
+    edition_type: Optional[Literal['basic', 'jinja2']]
 
 
 class CompletionModelPromptTemplate(BaseModel):
@@ -18,6 +19,7 @@ class CompletionModelPromptTemplate(BaseModel):
     Completion Model Prompt Template.
     """
     text: str
+    edition_type: Optional[Literal['basic', 'jinja2']]
 
 
 class MemoryConfig(BaseModel):

+ 20 - 1
api/core/workflow/nodes/llm/entities.py

@@ -4,6 +4,7 @@ from pydantic import BaseModel
 
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
 from core.workflow.entities.base_node_data_entities import BaseNodeData
+from core.workflow.entities.variable_entities import VariableSelector
 
 
 class ModelConfig(BaseModel):
@@ -37,13 +38,31 @@ class VisionConfig(BaseModel):
     enabled: bool
     configs: Optional[Configs] = None
 
+class PromptConfig(BaseModel):
+    """
+    Prompt Config.
+    """
+    jinja2_variables: Optional[list[VariableSelector]] = None
+
+class LLMNodeChatModelMessage(ChatModelMessage):
+    """
+    LLM Node Chat Model Message.
+    """
+    jinja2_text: Optional[str] = None
+
+class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
+    """
+    LLM Node Chat Model Prompt Template.
+    """
+    jinja2_text: Optional[str] = None
 
 class LLMNodeData(BaseNodeData):
     """
     LLM Node Data.
     """
     model: ModelConfig
-    prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate]
+    prompt_template: Union[list[LLMNodeChatModelMessage], LLMNodeCompletionModelPromptTemplate]
+    prompt_config: Optional[PromptConfig] = None
     memory: Optional[MemoryConfig] = None
     context: ContextConfig
     vision: VisionConfig

+ 125 - 13
api/core/workflow/nodes/llm/llm_node.py

@@ -1,4 +1,6 @@
+import json
 from collections.abc import Generator
+from copy import deepcopy
 from typing import Optional, cast
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
@@ -17,11 +19,15 @@ from core.model_runtime.utils.encoders import jsonable_encoder
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
-from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
-from core.workflow.nodes.llm.entities import LLMNodeData, ModelConfig
+from core.workflow.nodes.llm.entities import (
+    LLMNodeChatModelMessage,
+    LLMNodeCompletionModelPromptTemplate,
+    LLMNodeData,
+    ModelConfig,
+)
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
 from extensions.ext_database import db
 from models.model import Conversation
@@ -39,16 +45,24 @@ class LLMNode(BaseNode):
         :param variable_pool: variable pool
         :return:
         """
-        node_data = self.node_data
-        node_data = cast(self._node_data_cls, node_data)
+        node_data = cast(LLMNodeData, deepcopy(self.node_data))
 
         node_inputs = None
         process_data = None
 
         try:
+            # init messages template
+            node_data.prompt_template = self._transform_chat_messages(node_data.prompt_template)
+
             # fetch variables and fetch values from variable pool
             inputs = self._fetch_inputs(node_data, variable_pool)
 
+            # fetch jinja2 inputs
+            jinja_inputs = self._fetch_jinja_inputs(node_data, variable_pool)
+
+            # merge inputs
+            inputs.update(jinja_inputs)
+
             node_inputs = {}
 
             # fetch files
@@ -183,6 +197,86 @@ class LLMNode(BaseNode):
             usage = LLMUsage.empty_usage()
 
         return full_text, usage
+    
+    def _transform_chat_messages(self, 
+        messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
+    ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
+        """
+        Transform chat messages
+
+        :param messages: chat messages
+        :return:
+        """
+
+        if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
+            if messages.edition_type == 'jinja2':
+                messages.text = messages.jinja2_text
+
+            return messages
+
+        for message in messages:
+            if message.edition_type == 'jinja2':
+                message.text = message.jinja2_text
+
+        return messages
+
+    def _fetch_jinja_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
+        """
+        Fetch jinja inputs
+        :param node_data: node data
+        :param variable_pool: variable pool
+        :return:
+        """
+        variables = {}
+
+        if not node_data.prompt_config:
+            return variables
+
+        for variable_selector in node_data.prompt_config.jinja2_variables or []:
+            variable = variable_selector.variable
+            value = variable_pool.get_variable_value(
+                variable_selector=variable_selector.value_selector
+            )
+
+            def parse_dict(d: dict) -> str:
+                """
+                Parse dict into string
+                """
+                # check if it's a context structure
+                if 'metadata' in d and '_source' in d['metadata'] and 'content' in d:
+                    return d['content']
+                
+                # else, parse the dict
+                try:
+                    return json.dumps(d, ensure_ascii=False)
+                except Exception:
+                    return str(d)
+                
+            if isinstance(value, str):
+                value = value
+            elif isinstance(value, list):
+                result = ''
+                for item in value:
+                    if isinstance(item, dict):
+                        result += parse_dict(item)
+                    elif isinstance(item, str):
+                        result += item
+                    elif isinstance(item, int | float):
+                        result += str(item)
+                    else:
+                        result += str(item)
+                    result += '\n'
+                value = result.strip()
+            elif isinstance(value, dict):
+                value = parse_dict(value)
+            elif isinstance(value, int | float):
+                value = str(value)
+            else:
+                value = str(value)
+
+            variables[variable] = value
+
+        return variables
 
     def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
         """
@@ -531,25 +625,25 @@ class LLMNode(BaseNode):
             db.session.commit()
 
     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(cls, node_data: LLMNodeData) -> dict[str, list[str]]:
         """
         Extract variable selector to variable mapping
         :param node_data: node data
         :return:
         """
-        node_data = node_data
-        node_data = cast(cls._node_data_cls, node_data)
 
         prompt_template = node_data.prompt_template
 
         variable_selectors = []
         if isinstance(prompt_template, list):
             for prompt in prompt_template:
-                variable_template_parser = VariableTemplateParser(template=prompt.text)
-                variable_selectors.extend(variable_template_parser.extract_variable_selectors())
+                if prompt.edition_type != 'jinja2':
+                    variable_template_parser = VariableTemplateParser(template=prompt.text)
+                    variable_selectors.extend(variable_template_parser.extract_variable_selectors())
         else:
-            variable_template_parser = VariableTemplateParser(template=prompt_template.text)
-            variable_selectors = variable_template_parser.extract_variable_selectors()
+            if prompt_template.edition_type != 'jinja2':
+                variable_template_parser = VariableTemplateParser(template=prompt_template.text)
+                variable_selectors = variable_template_parser.extract_variable_selectors()
 
         variable_mapping = {}
         for variable_selector in variable_selectors:
@@ -571,6 +665,22 @@ class LLMNode(BaseNode):
         if node_data.memory:
             variable_mapping['#sys.query#'] = ['sys', SystemVariable.QUERY.value]
 
+        if node_data.prompt_config:
+            enable_jinja = False
+
+            if isinstance(prompt_template, list):
+                for prompt in prompt_template:
+                    if prompt.edition_type == 'jinja2':
+                        enable_jinja = True
+                        break
+            else:
+                if prompt_template.edition_type == 'jinja2':
+                    enable_jinja = True
+
+            if enable_jinja:
+                for variable_selector in node_data.prompt_config.jinja2_variables or []:
+                    variable_mapping[variable_selector.variable] = variable_selector.value_selector
+
         return variable_mapping
 
     @classmethod
@@ -588,7 +698,8 @@ class LLMNode(BaseNode):
                         "prompts": [
                             {
                                 "role": "system",
-                                "text": "You are a helpful AI assistant."
+                                "text": "You are a helpful AI assistant.",
+                                "edition_type": "basic"
                             }
                         ]
                     },
@@ -600,7 +711,8 @@ class LLMNode(BaseNode):
                         "prompt": {
                             "text": "Here is the chat histories between human and assistant, inside "
                                     "<histories></histories> XML tags.\n\n<histories>\n{{"
-                                    "#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:"
+                                    "#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:",
+                            "edition_type": "basic"
                         },
                         "stop": ["Human:"]
                     }

+ 1 - 0
api/requirements-dev.txt

@@ -3,3 +3,4 @@ pytest~=8.1.1
 pytest-benchmark~=4.0.0
 pytest-env~=1.1.3
 pytest-mock~=3.14.0
+jinja2~=3.1.2

+ 2 - 1
api/tests/integration_tests/workflow/nodes/__mock/code_executor.py

@@ -3,6 +3,7 @@ from typing import Literal
 
 import pytest
 from _pytest.monkeypatch import MonkeyPatch
+from jinja2 import Template
 
 from core.helper.code_executor.code_executor import CodeExecutor
 
@@ -18,7 +19,7 @@ class MockedCodeExecutor:
             }
         elif language == 'jinja2':
             return {
-                "result": "3"
+                "result": Template(code).render(inputs)
             }
 
 @pytest.fixture

+ 117 - 0
api/tests/integration_tests/workflow/nodes/test_llm.py

@@ -1,3 +1,4 @@
+import json
 import os
 from unittest.mock import MagicMock
 
@@ -19,6 +20,7 @@ from models.workflow import WorkflowNodeExecutionStatus
 
 """FOR MOCK FIXTURES, DO NOT REMOVE"""
 from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
+from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
 
 
 @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
@@ -116,3 +118,118 @@ def test_execute_llm(setup_openai_mock):
     assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
     assert result.outputs['text'] is not None
     assert result.outputs['usage']['total_tokens'] > 0
+
+@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True)
+@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
+def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
+    """
+    Test execute LLM node with jinja2
+    """
+    node = LLMNode(
+        tenant_id='1',
+        app_id='1',
+        workflow_id='1',
+        user_id='1',
+        user_from=UserFrom.ACCOUNT,
+        config={
+            'id': 'llm',
+            'data': {
+                'title': '123',
+                'type': 'llm',
+                'model': {
+                    'provider': 'openai',
+                    'name': 'gpt-3.5-turbo',
+                    'mode': 'chat',
+                    'completion_params': {}
+                },
+                'prompt_config': {
+                    'jinja2_variables': [{
+                        'variable': 'sys_query',
+                        'value_selector': ['sys', 'query']
+                    }, {
+                        'variable': 'output',
+                        'value_selector': ['abc', 'output']
+                    }]
+                },
+                'prompt_template': [
+                    {
+                        'role': 'system',
+                        'text': 'you are a helpful assistant.\ntoday\'s weather is {{#abc.output#}}',
+                        'jinja2_text': 'you are a helpful assistant.\ntoday\'s weather is {{output}}.',
+                        'edition_type': 'jinja2'
+                    },
+                    {
+                        'role': 'user',
+                        'text': '{{#sys.query#}}',
+                        'jinja2_text': '{{sys_query}}',
+                        'edition_type': 'basic'
+                    }
+                ],
+                'memory': None,
+                'context': {
+                    'enabled': False
+                },
+                'vision': {
+                    'enabled': False
+                }
+            }
+        }
+    )
+
+    # construct variable pool
+    pool = VariablePool(system_variables={
+        SystemVariable.QUERY: 'what\'s the weather today?',
+        SystemVariable.FILES: [],
+        SystemVariable.CONVERSATION_ID: 'abababa',
+        SystemVariable.USER_ID: 'aaa'
+    }, user_inputs={})
+    pool.append_variable(node_id='abc', variable_key_list=['output'], value='sunny')
+
+    credentials = {
+        'openai_api_key': os.environ.get('OPENAI_API_KEY')
+    }
+
+    provider_instance = ModelProviderFactory().get_provider_instance('openai')
+    model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
+    provider_model_bundle = ProviderModelBundle(
+        configuration=ProviderConfiguration(
+            tenant_id='1',
+            provider=provider_instance.get_provider_schema(),
+            preferred_provider_type=ProviderType.CUSTOM,
+            using_provider_type=ProviderType.CUSTOM,
+            system_configuration=SystemConfiguration(
+                enabled=False
+            ),
+            custom_configuration=CustomConfiguration(
+                provider=CustomProviderConfiguration(
+                    credentials=credentials
+                )
+            )
+        ),
+        provider_instance=provider_instance,
+        model_type_instance=model_type_instance
+    )
+
+    model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo')
+
+    model_config = ModelConfigWithCredentialsEntity(
+        model='gpt-3.5-turbo',
+        provider='openai',
+        mode='chat',
+        credentials=credentials,
+        parameters={},
+        model_schema=model_type_instance.get_model_schema('gpt-3.5-turbo'),
+        provider_model_bundle=provider_model_bundle
+    )
+
+    # Mock db.session.close()
+    db.session.close = MagicMock()
+
+    node._fetch_model_config = MagicMock(return_value=tuple([model_instance, model_config]))
+
+    # execute node
+    result = node.run(pool)
+
+    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+    assert 'sunny' in json.dumps(result.process_data)
+    assert 'what\'s the weather today?' in json.dumps(result.process_data)