feat: optimize template parse (#460)

John Wang, 1 year ago
parent
commit
c720f831af

+ 0 - 4
api/core/__init__.py

@@ -3,7 +3,6 @@ from typing import Optional
 
 import langchain
 from flask import Flask
-from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING
 from pydantic import BaseModel
 
 from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
@@ -22,9 +21,6 @@ hosted_llm_credentials = HostedLLMCredentials()
 
 
 def init_app(app: Flask):
-    formatter = OneLineFormatter()
-    DEFAULT_FORMATTER_MAPPING['f-string'] = formatter.format
-
     if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
         langchain.verbose = True
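This hunk removes the startup hook that globally swapped langchain's f-string formatter for OneLineFormatter. With the move to Jinja2 templates elsewhere in this commit, rendering goes through langchain's built-in jinja2 formatter instead; a minimal sketch of that path (illustrative, not part of the diff):

    from langchain import PromptTemplate

    # With template_format="jinja2", langchain renders through its built-in
    # jinja2 formatter, so no DEFAULT_FORMATTER_MAPPING patch is needed.
    prompt = PromptTemplate(
        template="Hello, {{ name }}!",
        input_variables=["name"],
        template_format="jinja2",
    )
    print(prompt.format(name="Dify"))  # -> Hello, Dify!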
 

+ 24 - 23
api/core/completion.py

@@ -23,7 +23,7 @@ from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
 from core.memory.read_only_conversation_token_db_string_buffer_shared_memory import \
     ReadOnlyConversationTokenDBStringBufferSharedMemory
 from core.prompt.prompt_builder import PromptBuilder
-from core.prompt.prompt_template import OutLinePromptTemplate
+from core.prompt.prompt_template import JinjaPromptTemplate
 from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
 from models.model import App, AppModelConfig, Account, Conversation, Message
 
@@ -35,6 +35,8 @@ class Completion:
         """
         errors: ProviderTokenNotInitError
         """
+        query = PromptBuilder.process_template(query)
+
         memory = None
         if conversation:
             # get memory of conversation (read-only)
@@ -141,18 +143,17 @@ class Completion:
                             memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
             Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
         # disable template string in query
-        query_params = OutLinePromptTemplate.from_template(template=query).input_variables
-        if query_params:
-            for query_param in query_params:
-                if query_param not in inputs:
-                    inputs[query_param] = '{' + query_param + '}'
+        # query_params = JinjaPromptTemplate.from_template(template=query).input_variables
+        # if query_params:
+        #     for query_param in query_params:
+        #         if query_param not in inputs:
+        #             inputs[query_param] = '{{' + query_param + '}}'
 
-        pre_prompt = PromptBuilder.process_template(pre_prompt) if pre_prompt else pre_prompt
         if mode == 'completion':
-            prompt_template = OutLinePromptTemplate.from_template(
+            prompt_template = JinjaPromptTemplate.from_template(
                 template=("""Use the following CONTEXT as your learned knowledge:
 [CONTEXT]
-{context}
+{{context}}
 [END CONTEXT]
 
 When answer to user:
@@ -162,16 +163,16 @@ Avoid mentioning that you obtained the information from the context.
 And answer according to the language of the user's question.
 """ if chain_output else "")
                          + (pre_prompt + "\n" if pre_prompt else "")
-                         + "{query}\n"
+                         + "{{query}}\n"
             )
 
             if chain_output:
                 inputs['context'] = chain_output
-                context_params = OutLinePromptTemplate.from_template(template=chain_output).input_variables
-                if context_params:
-                    for context_param in context_params:
-                        if context_param not in inputs:
-                            inputs[context_param] = '{' + context_param + '}'
+                # context_params = JinjaPromptTemplate.from_template(template=chain_output).input_variables
+                # if context_params:
+                #     for context_param in context_params:
+                #         if context_param not in inputs:
+                #             inputs[context_param] = '{{' + context_param + '}}'
 
             prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
             prompt_content = prompt_template.format(
@@ -195,7 +196,7 @@ And answer according to the language of the user's question.
 
             if pre_prompt:
                 pre_prompt_inputs = {k: inputs[k] for k in
-                                     OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
+                                     JinjaPromptTemplate.from_template(template=pre_prompt).input_variables
                                      if k in inputs}
 
                 if pre_prompt_inputs:
@@ -205,7 +206,7 @@ And answer according to the language of the user's question.
                 human_inputs['context'] = chain_output
                 human_message_prompt += """Use the following CONTEXT as your learned knowledge.
 [CONTEXT]
-{context}
+{{context}}
 [END CONTEXT]
 
 When answer to user:
@@ -218,7 +219,7 @@ And answer according to the language of the user's question.
             if pre_prompt:
                 human_message_prompt += pre_prompt
 
-            query_prompt = "\nHuman: {query}\nAI: "
+            query_prompt = "\nHuman: {{query}}\nAI: "
 
             if memory:
                 # append chat histories
@@ -234,11 +235,11 @@ And answer according to the language of the user's question.
                 histories = cls.get_history_messages_from_memory(memory, rest_tokens)
 
                 # disable template string in query
-                histories_params = OutLinePromptTemplate.from_template(template=histories).input_variables
-                if histories_params:
-                    for histories_param in histories_params:
-                        if histories_param not in human_inputs:
-                            human_inputs[histories_param] = '{' + histories_param + '}'
+                # histories_params = JinjaPromptTemplate.from_template(template=histories).input_variables
+                # if histories_params:
+                #     for histories_param in histories_params:
+                #         if histories_param not in human_inputs:
+                #             human_inputs[histories_param] = '{{' + histories_param + '}}'
 
                 human_message_prompt += "\n\n" + histories
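The hunks above switch the completion-mode templates from f-string {var} placeholders to Jinja2 {{var}}, and comment out the old escaping of user-supplied brace sequences. A small sketch of how the new template resolves (hypothetical values, not from the diff):

    # Illustrative use of JinjaPromptTemplate as introduced by this commit.
    template = JinjaPromptTemplate.from_template(
        "Use the following CONTEXT as your learned knowledge:\n"
        "[CONTEXT]\n{{context}}\n[END CONTEXT]\n\n{{query}}\n"
    )
    print(template.input_variables)  # ['context', 'query']
    print(template.format(context="(retrieved passages)", query="What is Dify?"))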
 

+ 3 - 4
api/core/conversation_message_task.py

@@ -10,7 +10,7 @@ from core.constant import llm_constant
 from core.llm.llm_builder import LLMBuilder
 from core.llm.provider.llm_provider_service import LLMProviderService
 from core.prompt.prompt_builder import PromptBuilder
-from core.prompt.prompt_template import OutLinePromptTemplate
+from core.prompt.prompt_template import JinjaPromptTemplate
 from events.message_event import message_was_created
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
@@ -78,7 +78,7 @@ class ConversationMessageTask:
         if self.mode == 'chat':
             introduction = self.app_model_config.opening_statement
             if introduction:
-                prompt_template = OutLinePromptTemplate.from_template(template=PromptBuilder.process_template(introduction))
+                prompt_template = JinjaPromptTemplate.from_template(template=introduction)
                 prompt_inputs = {k: self.inputs[k] for k in prompt_template.input_variables if k in self.inputs}
                 try:
                     introduction = prompt_template.format(**prompt_inputs)
@@ -86,8 +86,7 @@ class ConversationMessageTask:
                     pass
 
             if self.app_model_config.pre_prompt:
-                pre_prompt = PromptBuilder.process_template(self.app_model_config.pre_prompt)
-                system_message = PromptBuilder.to_system_message(pre_prompt, self.inputs)
+                system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs)
                 system_instruction = system_message.content
                 llm = LLMBuilder.to_llm(self.tenant_id, self.model_name)
                 system_instruction_tokens = llm.get_messages_tokens([system_message])
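After this change the raw pre-prompt is handed straight to PromptBuilder.to_system_message, whose JinjaPromptTemplate resolves the {{...}} placeholders itself, so the process_template pre-pass is no longer needed here. Roughly (hypothetical pre-prompt and inputs):

    # Hypothetical values; to_system_message is shown in prompt_builder.py below.
    message = PromptBuilder.to_system_message(
        "You are a support agent for {{product}}.",
        {"product": "Dify"},
    )
    print(message.content)  # -> You are a support agent for Dify.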

+ 4 - 3
api/core/generator/llm_generator.py

@@ -1,5 +1,6 @@
 import logging
 
+from langchain import PromptTemplate
 from langchain.chat_models.base import BaseChatModel
 from langchain.schema import HumanMessage, OutputParserException
 
@@ -10,7 +11,7 @@ from core.llm.token_calculator import TokenCalculator
 from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
 
 from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
-from core.prompt.prompt_template import OutLinePromptTemplate
+from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTemplate
 from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT
 
 
@@ -91,8 +92,8 @@ class LLMGenerator:
         output_parser = SuggestedQuestionsAfterAnswerOutputParser()
         format_instructions = output_parser.get_format_instructions()
 
-        prompt = OutLinePromptTemplate(
-            template="{histories}\n{format_instructions}\nquestions:\n",
+        prompt = JinjaPromptTemplate(
+            template="{{histories}}\n{{format_instructions}}\nquestions:\n",
             input_variables=["histories"],
             partial_variables={"format_instructions": format_instructions}
         )
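A quick sketch of the partial-variables behavior here: format_instructions is pre-bound, so histories stays the only runtime input (the instruction string below is an illustrative stand-in):

    prompt = JinjaPromptTemplate(
        template="{{histories}}\n{{format_instructions}}\nquestions:\n",
        input_variables=["histories"],
        partial_variables={"format_instructions": "Return a JSON array of questions."},
    )
    # Only 'histories' needs to be supplied at format time.
    print(prompt.format(histories="Human: hi\nAI: hello"))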

+ 7 - 6
api/core/prompt/prompt_builder.py

@@ -3,13 +3,13 @@ import re
 from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, AIMessagePromptTemplate
 from langchain.schema import BaseMessage
 
-from core.prompt.prompt_template import OutLinePromptTemplate
+from core.prompt.prompt_template import JinjaPromptTemplate
 
 
 class PromptBuilder:
     @classmethod
     def to_system_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
-        prompt_template = OutLinePromptTemplate.from_template(prompt_content)
+        prompt_template = JinjaPromptTemplate.from_template(prompt_content)
         system_prompt_template = SystemMessagePromptTemplate(prompt=prompt_template)
         prompt_inputs = {k: inputs[k] for k in system_prompt_template.input_variables if k in inputs}
         system_message = system_prompt_template.format(**prompt_inputs)
@@ -17,7 +17,7 @@ class PromptBuilder:
 
     @classmethod
     def to_ai_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
-        prompt_template = OutLinePromptTemplate.from_template(prompt_content)
+        prompt_template = JinjaPromptTemplate.from_template(prompt_content)
         ai_prompt_template = AIMessagePromptTemplate(prompt=prompt_template)
         prompt_inputs = {k: inputs[k] for k in ai_prompt_template.input_variables if k in inputs}
         ai_message = ai_prompt_template.format(**prompt_inputs)
@@ -25,13 +25,14 @@ class PromptBuilder:
 
     @classmethod
     def to_human_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
-        prompt_template = OutLinePromptTemplate.from_template(prompt_content)
+        prompt_template = JinjaPromptTemplate.from_template(prompt_content)
         human_prompt_template = HumanMessagePromptTemplate(prompt=prompt_template)
         human_message = human_prompt_template.format(**inputs)
         return human_message
 
     @classmethod
     def process_template(cls, template: str):
-        processed_template = re.sub(r'\{([a-zA-Z_]\w+?)\}', r'\1', template)
-        processed_template = re.sub(r'\{\{([a-zA-Z_]\w+?)\}\}', r'{\1}', processed_template)
+        processed_template = re.sub(r'\{{2}(.+)\}{2}', r'{\1}', template)
+        # processed_template = re.sub(r'\{([a-zA-Z_]\w+?)\}', r'\1', template)
+        # processed_template = re.sub(r'\{\{([a-zA-Z_]\w+?)\}\}', r'{\1}', processed_template)
         return processed_template
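The reworked process_template now collapses double-brace placeholders to single braces in one pass, replacing the old escape/unescape pair. A minimal check of the new behavior (illustrative input):

    import re

    def process_template(template: str) -> str:
        # Collapse "{{var}}" into "{var}", as in this commit.
        return re.sub(r'\{{2}(.+)\}{2}', r'{\1}', template)

    print(process_template("Hello, {{name}}!"))  # -> Hello, {name}!

Worth noting: the .+ group is greedy, so two placeholders on the same line are captured as a single match.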

+ 41 - 0
api/core/prompt/prompt_template.py

@@ -1,10 +1,33 @@
 import re
 from typing import Any
 
+from jinja2 import Environment, meta
 from langchain import PromptTemplate
 from langchain.formatting import StrictFormatter
 
 
+class JinjaPromptTemplate(PromptTemplate):
+    template_format: str = "jinja2"
+    """The format of the prompt template. Options are: 'f-string', 'jinja2'."""
+
+    @classmethod
+    def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
+        """Load a prompt template from a template."""
+        env = Environment()
+        ast = env.parse(template)
+        input_variables = meta.find_undeclared_variables(ast)
+
+        if "partial_variables" in kwargs:
+            partial_variables = kwargs["partial_variables"]
+            input_variables = {
+                var for var in input_variables if var not in partial_variables
+            }
+
+        return cls(
+            input_variables=list(sorted(input_variables)), template=template, **kwargs
+        )
+
+
 class OutLinePromptTemplate(PromptTemplate):
     @classmethod
     def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
@@ -16,6 +39,24 @@ class OutLinePromptTemplate(PromptTemplate):
             input_variables=list(sorted(input_variables)), template=template, **kwargs
         )
 
+    def format(self, **kwargs: Any) -> str:
+        """Format the prompt with the inputs.
+
+        Args:
+            kwargs: Any arguments to be passed to the prompt template.
+
+        Returns:
+            A formatted string.
+
+        Example:
+
+        .. code-block:: python
+
+            prompt.format(variable1="foo")
+        """
+        kwargs = self._merge_partial_and_user_variables(**kwargs)
+        return OneLineFormatter().format(self.template, **kwargs)
+
 
 class OneLineFormatter(StrictFormatter):
     def parse(self, format_string):
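The new from_template discovers input variables with jinja2's meta API; in isolation that step behaves like this:

    from jinja2 import Environment, meta

    env = Environment()
    ast = env.parse("{{histories}}\n{{format_instructions}}\nquestions:\n")
    # Returns a set of undeclared variable names (order may vary):
    print(meta.find_undeclared_variables(ast))  # e.g. {'histories', 'format_instructions'}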

+ 4 - 4
api/core/prompt/prompts.py

@@ -1,5 +1,5 @@
 CONVERSATION_TITLE_PROMPT = (
-    "Human:{query}\n-----\n"
+    "Human:{{query}}\n-----\n"
     "Help me summarize the intent of what the human said and provide a title, the title should not exceed 20 words.\n"
     "If the human said is conducted in Chinese, you should return a Chinese title.\n" 
     "If the human said is conducted in English, you should return an English title.\n"
@@ -19,7 +19,7 @@ CONVERSATION_SUMMARY_PROMPT = (
 INTRODUCTION_GENERATE_PROMPT = (
     "I am designing a product for users to interact with an AI through dialogue. "
     "The Prompt given to the AI before the conversation is:\n\n"
-    "```\n{prompt}\n```\n\n"
+    "```\n{{prompt}}\n```\n\n"
     "Please generate a brief introduction of no more than 50 words that greets the user, based on this Prompt. "
     "Do not reveal the developer's motivation or deep logic behind the Prompt, "
     "but focus on building a relationship with the user:\n"
@@ -27,13 +27,13 @@ INTRODUCTION_GENERATE_PROMPT = (
 
 MORE_LIKE_THIS_GENERATE_PROMPT = (
     "-----\n"
-    "{original_completion}\n"
+    "{{original_completion}}\n"
     "-----\n\n"
     "Please use the above content as a sample for generating the result, "
     "and include key information points related to the original sample in the result. "
     "Try to rephrase this information in different ways and predict according to the rules below.\n\n"
     "-----\n"
-    "{prompt}\n"
+    "{{prompt}}\n"
 )
 
 SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
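End to end, the converted constants render through JinjaPromptTemplate; for example (hypothetical query, assuming the imports above):

    prompt = JinjaPromptTemplate.from_template(CONVERSATION_TITLE_PROMPT)
    print(prompt.input_variables)  # ['query']
    print(prompt.format(query="How do I deploy Dify with Docker?"))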