Explorar el Código

feat: query prompt template support in chatflow (#3791)

Co-authored-by: Joel <iamjoel007@gmail.com>
takatost hace 1 año
padre
commit
12435774ca

+ 19 - 2
api/core/prompt/advanced_prompt_transform.py

@@ -31,7 +31,8 @@ class AdvancedPromptTransform(PromptTransform):
                    context: Optional[str],
                    memory_config: Optional[MemoryConfig],
                    memory: Optional[TokenBufferMemory],
-                   model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
+                   model_config: ModelConfigWithCredentialsEntity,
+                   query_prompt_template: Optional[str] = None) -> list[PromptMessage]:
         inputs = {key: str(value) for key, value in inputs.items()}
 
         prompt_messages = []
@@ -53,6 +54,7 @@ class AdvancedPromptTransform(PromptTransform):
                 prompt_template=prompt_template,
                 inputs=inputs,
                 query=query,
+                query_prompt_template=query_prompt_template,
                 files=files,
                 context=context,
                 memory_config=memory_config,
@@ -121,7 +123,8 @@ class AdvancedPromptTransform(PromptTransform):
                                         context: Optional[str],
                                         memory_config: Optional[MemoryConfig],
                                         memory: Optional[TokenBufferMemory],
-                                        model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
+                                        model_config: ModelConfigWithCredentialsEntity,
+                                        query_prompt_template: Optional[str] = None) -> list[PromptMessage]:
         """
         Get chat model prompt messages.
         """
@@ -148,6 +151,20 @@ class AdvancedPromptTransform(PromptTransform):
             elif prompt_item.role == PromptMessageRole.ASSISTANT:
                 prompt_messages.append(AssistantPromptMessage(content=prompt))
 
+        if query and query_prompt_template:
+            prompt_template = PromptTemplateParser(
+                template=query_prompt_template,
+                with_variable_tmpl=self.with_variable_tmpl
+            )
+            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+            prompt_inputs['#sys.query#'] = query
+
+            prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
+
+            query = prompt_template.format(
+                prompt_inputs
+            )
+
         if memory and memory_config:
             prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
 

+ 1 - 0
api/core/prompt/entities/advanced_prompt_entities.py

@@ -40,3 +40,4 @@ class MemoryConfig(BaseModel):
 
     role_prefix: Optional[RolePrefix] = None
     window: WindowConfig
+    query_prompt_template: Optional[str] = None

+ 25 - 2
api/core/workflow/nodes/llm/llm_node.py

@@ -74,6 +74,7 @@ class LLMNode(BaseNode):
                 node_data=node_data,
                 query=variable_pool.get_variable_value(['sys', SystemVariable.QUERY.value])
                 if node_data.memory else None,
+                query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
                 inputs=inputs,
                 files=files,
                 context=context,
@@ -209,6 +210,17 @@ class LLMNode(BaseNode):
 
             inputs[variable_selector.variable] = variable_value
 
+        memory = node_data.memory
+        if memory and memory.query_prompt_template:
+            query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template)
+                                        .extract_variable_selectors())
+            for variable_selector in query_variable_selectors:
+                variable_value = variable_pool.get_variable_value(variable_selector.value_selector)
+                if variable_value is None:
+                    raise ValueError(f'Variable {variable_selector.variable} not found')
+
+                inputs[variable_selector.variable] = variable_value
+
         return inputs
 
     def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileVar]:
@@ -302,7 +314,8 @@ class LLMNode(BaseNode):
 
         return None
 
-    def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
+    def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[
+        ModelInstance, ModelConfigWithCredentialsEntity]:
         """
         Fetch model config
         :param node_data_model: node data model
@@ -407,6 +420,7 @@ class LLMNode(BaseNode):
 
     def _fetch_prompt_messages(self, node_data: LLMNodeData,
                                query: Optional[str],
+                               query_prompt_template: Optional[str],
                                inputs: dict[str, str],
                                files: list[FileVar],
                                context: Optional[str],
@@ -417,6 +431,7 @@ class LLMNode(BaseNode):
         Fetch prompt messages
         :param node_data: node data
         :param query: query
+        :param query_prompt_template: query prompt template
         :param inputs: inputs
         :param files: files
         :param context: context
@@ -433,7 +448,8 @@ class LLMNode(BaseNode):
             context=context,
             memory_config=node_data.memory,
             memory=memory,
-            model_config=model_config
+            model_config=model_config,
+            query_prompt_template=query_prompt_template,
         )
         stop = model_config.stop
 
@@ -539,6 +555,13 @@ class LLMNode(BaseNode):
         for variable_selector in variable_selectors:
             variable_mapping[variable_selector.variable] = variable_selector.value_selector
 
+        memory = node_data.memory
+        if memory and memory.query_prompt_template:
+            query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template)
+                                        .extract_variable_selectors())
+            for variable_selector in query_variable_selectors:
+                variable_mapping[variable_selector.variable] = variable_selector.value_selector
+
         if node_data.context.enabled:
             variable_mapping['#context#'] = node_data.context.variable_selector
 

+ 3 - 0
web/app/components/base/prompt-editor/constants.tsx

@@ -30,6 +30,9 @@ export const checkHasQueryBlock = (text: string) => {
 * {{#1711617514996.sys.query#}} => [sys, query]
 */
 export const getInputVars = (text: string): ValueSelector[] => {
+  if (!text)
+    return []
+
   const allVars = text.match(/{{#([^#]*)#}}/g)
   if (allVars && allVars?.length > 0) {
     // {{#context#}}, {{#query#}} is not input vars

+ 1 - 0
web/app/components/workflow/nodes/_base/components/prompt/editor.tsx

@@ -146,6 +146,7 @@ const Editor: FC<Props> = ({
               <PromptEditor
                 instanceId={instanceId}
                 compact
+                className='min-h-[56px]'
                 style={isExpand ? { height: editorExpandHeight - 5 } : {}}
                 value={value}
                 contextBlock={{

+ 7 - 3
web/app/components/workflow/nodes/_base/components/variable/utils.ts

@@ -272,10 +272,12 @@ export const getNodeUsedVars = (node: Node): ValueSelector[] => {
       const payload = (data as LLMNodeType)
       const isChatModel = payload.model?.mode === 'chat'
       let prompts: string[] = []
-      if (isChatModel)
+      if (isChatModel) {
         prompts = (payload.prompt_template as PromptItem[])?.map(p => p.text) || []
-      else
-        prompts = [(payload.prompt_template as PromptItem).text]
+        if (payload.memory?.query_prompt_template)
+          prompts.push(payload.memory.query_prompt_template)
+      }
+      else { prompts = [(payload.prompt_template as PromptItem).text] }
 
       const inputVars: ValueSelector[] = matchNotSystemVars(prompts)
       const contextVar = (data as LLMNodeType).context?.variable_selector ? [(data as LLMNodeType).context?.variable_selector] : []
@@ -375,6 +377,8 @@ export const updateNodeVars = (oldNode: Node, oldVarSelector: ValueSelector, new
               text: replaceOldVarInText(prompt.text, oldVarSelector, newVarSelector),
             }
           })
+          if (payload.memory?.query_prompt_template)
+            payload.memory.query_prompt_template = replaceOldVarInText(payload.memory.query_prompt_template, oldVarSelector, newVarSelector)
         }
         else {
           payload.prompt_template = {

+ 7 - 0
web/app/components/workflow/nodes/llm/default.ts

@@ -50,6 +50,13 @@ const nodeDefault: NodeDefault<LLMNodeType> = {
       if (isPromptyEmpty)
         errorMessages = t(`${i18nPrefix}.fieldRequired`, { field: t('workflow.nodes.llm.prompt') })
     }
+
+    if (!errorMessages && !!payload.memory) {
+      const isChatModel = payload.model.mode === 'chat'
+      // payload.memory.query_prompt_template not pass is default: {{#sys.query#}}
+      if (isChatModel && !!payload.memory.query_prompt_template && !payload.memory.query_prompt_template.includes('{{#sys.query#}}'))
+        errorMessages = t('workflow.nodes.llm.sysQueryInUser')
+    }
     return {
       isValid: !errorMessages,
       errorMessage: errorMessages,

+ 14 - 10
web/app/components/workflow/nodes/llm/panel.tsx

@@ -50,7 +50,10 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
     handleContextVarChange,
     filterInputVar,
     filterVar,
+    availableVars,
+    availableNodes,
     handlePromptChange,
+    handleSyeQueryChange,
     handleMemoryChange,
     handleVisionResolutionEnabledChange,
     handleVisionResolutionChange,
@@ -204,19 +207,20 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
                     <HelpCircle className='w-3.5 h-3.5 text-gray-400' />
                   </TooltipPlus>
                 </div>}
-                value={'{{#sys.query#}}'}
-                onChange={() => { }}
-                readOnly
+                value={inputs.memory.query_prompt_template || '{{#sys.query#}}'}
+                onChange={handleSyeQueryChange}
+                readOnly={readOnly}
                 isShowContext={false}
                 isChatApp
-                isChatModel={false}
-                hasSetBlockStatus={{
-                  query: false,
-                  history: true,
-                  context: true,
-                }}
-                availableNodes={[startNode!]}
+                isChatModel
+                hasSetBlockStatus={hasSetBlockStatus}
+                nodesOutputVars={availableVars}
+                availableNodes={availableNodes}
               />
+
+              {inputs.memory.query_prompt_template && !inputs.memory.query_prompt_template.includes('{{#sys.query#}}') && (
+                <div className='leading-[18px] text-xs font-normal text-[#DC6803]'>{t(`${i18nPrefix}.sysQueryInUser`)}</div>
+              )}
             </div>
           </div>
         )}

+ 33 - 1
web/app/components/workflow/nodes/llm/use-config.ts

@@ -8,6 +8,7 @@ import {
   useIsChatMode,
   useNodesReadOnly,
 } from '../../hooks'
+import useAvailableVarList from '../_base/hooks/use-available-var-list'
 import type { LLMNodeType } from './types'
 import { Resolution } from '@/types/app'
 import { useModelListAndDefaultModelAndCurrentProviderAndModel, useTextGenerationCurrentProviderAndModelAndModelList } from '@/app/components/header/account-setting/model-provider-page/hooks'
@@ -206,6 +207,24 @@ const useConfig = (id: string, payload: LLMNodeType) => {
     setInputs(newInputs)
   }, [inputs, setInputs])
 
+  const handleSyeQueryChange = useCallback((newQuery: string) => {
+    const newInputs = produce(inputs, (draft) => {
+      if (!draft.memory) {
+        draft.memory = {
+          window: {
+            enabled: false,
+            size: 10,
+          },
+          query_prompt_template: newQuery,
+        }
+      }
+      else {
+        draft.memory.query_prompt_template = newQuery
+      }
+    })
+    setInputs(newInputs)
+  }, [inputs, setInputs])
+
   const handleVisionResolutionEnabledChange = useCallback((enabled: boolean) => {
     const newInputs = produce(inputs, (draft) => {
       if (!draft.vision) {
@@ -248,6 +267,14 @@ const useConfig = (id: string, payload: LLMNodeType) => {
     return [VarType.arrayObject, VarType.array, VarType.string].includes(varPayload.type)
   }, [])
 
+  const {
+    availableVars,
+    availableNodes,
+  } = useAvailableVarList(id, {
+    onlyLeafNodeVar: false,
+    filterVar,
+  })
+
   // single run
   const {
     isShowSingleRun,
@@ -322,8 +349,10 @@ const useConfig = (id: string, payload: LLMNodeType) => {
 
   const allVarStrArr = (() => {
     const arr = isChatModel ? (inputs.prompt_template as PromptItem[]).map(item => item.text) : [(inputs.prompt_template as PromptItem).text]
-    if (isChatMode && isChatModel && !!inputs.memory)
+    if (isChatMode && isChatModel && !!inputs.memory) {
       arr.push('{{#sys.query#}}')
+      arr.push(inputs.memory.query_prompt_template)
+    }
 
     return arr
   })()
@@ -346,8 +375,11 @@ const useConfig = (id: string, payload: LLMNodeType) => {
     handleContextVarChange,
     filterInputVar,
     filterVar,
+    availableVars,
+    availableNodes,
     handlePromptChange,
     handleMemoryChange,
+    handleSyeQueryChange,
     handleVisionResolutionEnabledChange,
     handleVisionResolutionChange,
     isShowSingleRun,

+ 1 - 0
web/app/components/workflow/types.ts

@@ -143,6 +143,7 @@ export type Memory = {
     enabled: boolean
     size: number | string | null
   }
+  query_prompt_template: string
 }
 
 export enum VarType {

+ 1 - 0
web/i18n/en-US/workflow.ts

@@ -204,6 +204,7 @@ const translation = {
       singleRun: {
         variable: 'Variable',
       },
+      sysQueryInUser: 'sys.query in user message is required',
     },
     knowledgeRetrieval: {
       queryVariable: 'Query Variable',

+ 1 - 0
web/i18n/zh-Hans/workflow.ts

@@ -204,6 +204,7 @@ const translation = {
       singleRun: {
         variable: '变量',
       },
+      sysQueryInUser: 'user message 中必须包含 sys.query',
     },
     knowledgeRetrieval: {
       queryVariable: '查询变量',