浏览代码

Fix: rerank switch and validation before run (#9416)

Yi Xiao 6 月之前
父节点
当前提交
8a1f106c72

+ 33 - 15
web/app/components/app/configuration/dataset-config/params-config/config-content.tsx

@@ -63,7 +63,7 @@ const ConfigContent: FC<Props> = ({
   } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
 
   const {
-    currentModel,
+    currentModel: currentRerankModel,
   } = useCurrentProviderAndModel(
     rerankModelList,
     rerankDefaultModel
@@ -74,11 +74,6 @@ const ConfigContent: FC<Props> = ({
       : undefined,
   )
 
-  const handleDisabledSwitchClick = useCallback(() => {
-    if (!currentModel)
-      Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
-  }, [currentModel, rerankDefaultModel, t])
-
   const rerankModel = (() => {
     if (datasetConfigs.reranking_model?.reranking_provider_name) {
       return {
@@ -164,12 +159,33 @@ const ConfigContent: FC<Props> = ({
   const showWeightedScorePanel = showWeightedScore && datasetConfigs.reranking_mode === RerankingModeEnum.WeightedScore && datasetConfigs.weights
   const selectedRerankMode = datasetConfigs.reranking_mode || RerankingModeEnum.RerankingModel
 
+  const canManuallyToggleRerank = useMemo(() => {
+    return !(
+      (selectedDatasetsMode.allInternal && selectedDatasetsMode.allEconomic)
+      || selectedDatasetsMode.allExternal
+    )
+  }, [selectedDatasetsMode.allEconomic, selectedDatasetsMode.allExternal, selectedDatasetsMode.allInternal])
+
   const showRerankModel = useMemo(() => {
-    if (datasetConfigs.reranking_enable === false && selectedDatasetsMode.allEconomic)
+    if (!canManuallyToggleRerank)
       return false
 
-    return true
-  }, [datasetConfigs.reranking_enable, selectedDatasetsMode.allEconomic])
+    return datasetConfigs.reranking_enable
+  }, [canManuallyToggleRerank, datasetConfigs.reranking_enable])
+
+  const handleDisabledSwitchClick = useCallback(() => {
+    if (!currentRerankModel && !showRerankModel)
+      Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
+  }, [currentRerankModel, showRerankModel, t])
+
+  useEffect(() => {
+    if (!canManuallyToggleRerank && showRerankModel !== datasetConfigs.reranking_enable) {
+      onChange({
+        ...datasetConfigs,
+        reranking_enable: showRerankModel,
+      })
+    }
+  }, [canManuallyToggleRerank, showRerankModel, datasetConfigs, onChange])
 
   return (
     <div>
@@ -256,13 +272,15 @@ const ConfigContent: FC<Props> = ({
                       >
                         <Switch
                           size='md'
-                          defaultValue={currentModel ? showRerankModel : false}
-                          disabled={!currentModel}
+                          defaultValue={showRerankModel}
+                          disabled={!currentRerankModel || !canManuallyToggleRerank}
                           onChange={(v) => {
-                            onChange({
-                              ...datasetConfigs,
-                              reranking_enable: v,
-                            })
+                            if (canManuallyToggleRerank) {
+                              onChange({
+                                ...datasetConfigs,
+                                reranking_enable: v,
+                              })
+                            }
                           }}
                         />
                       </div>

+ 2 - 1
web/app/components/app/configuration/dataset-config/params-config/index.tsx

@@ -42,6 +42,7 @@ const ParamsConfig = ({
       allHighQuality,
       allHighQualityFullTextSearch,
       allHighQualityVectorSearch,
+      allInternal,
       allExternal,
       mixtureHighQualityAndEconomic,
       inconsistentEmbeddingModel,
@@ -50,7 +51,7 @@ const ParamsConfig = ({
     const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs
     let rerankEnable = restConfigs.reranking_enable
 
-    if ((allEconomic && !restConfigs.reranking_model?.reranking_provider_name && rerankEnable === undefined) || allExternal)
+    if (((allInternal && allEconomic) || allExternal) && !restConfigs.reranking_model?.reranking_provider_name && rerankEnable === undefined)
       rerankEnable = false
 
     if (allEconomic || allHighQuality || allHighQualityFullTextSearch || allHighQualityVectorSearch || (allExternal && selectedDatasets.length === 1))

+ 0 - 55
web/app/components/workflow/hooks/use-workflow-start-run.tsx

@@ -1,25 +1,17 @@
 import { useCallback } from 'react'
 import { useStoreApi } from 'reactflow'
-import { useTranslation } from 'react-i18next'
 import { useWorkflowStore } from '../store'
 import {
   BlockEnum,
   WorkflowRunningStatus,
 } from '../types'
-import type { KnowledgeRetrievalNodeType } from '../nodes/knowledge-retrieval/types'
-import type { Node } from '../types'
-import { useWorkflow } from './use-workflow'
 import {
   useIsChatMode,
   useNodesSyncDraft,
   useWorkflowInteractions,
   useWorkflowRun,
 } from './index'
-import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
-import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
 import { useFeaturesStore } from '@/app/components/base/features/hooks'
-import KnowledgeRetrievalDefault from '@/app/components/workflow/nodes/knowledge-retrieval/default'
-import Toast from '@/app/components/base/toast'
 
 export const useWorkflowStartRun = () => {
   const store = useStoreApi()
@@ -28,26 +20,7 @@ export const useWorkflowStartRun = () => {
   const isChatMode = useIsChatMode()
   const { handleCancelDebugAndPreviewPanel } = useWorkflowInteractions()
   const { handleRun } = useWorkflowRun()
-  const { isFromStartNode } = useWorkflow()
   const { doSyncWorkflowDraft } = useNodesSyncDraft()
-  const { checkValid: checkKnowledgeRetrievalValid } = KnowledgeRetrievalDefault
-  const { t } = useTranslation()
-  const {
-    modelList: rerankModelList,
-    defaultModel: rerankDefaultModel,
-  } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
-
-  const {
-    currentModel,
-  } = useCurrentProviderAndModel(
-    rerankModelList,
-    rerankDefaultModel
-      ? {
-        ...rerankDefaultModel,
-        provider: rerankDefaultModel.provider.provider,
-      }
-      : undefined,
-  )
 
   const handleWorkflowStartRunInWorkflow = useCallback(async () => {
     const {
@@ -60,9 +33,6 @@ export const useWorkflowStartRun = () => {
     const { getNodes } = store.getState()
     const nodes = getNodes()
     const startNode = nodes.find(node => node.data.type === BlockEnum.Start)
-    const knowledgeRetrievalNodes = nodes.filter((node: Node<KnowledgeRetrievalNodeType>) =>
-      node.data.type === BlockEnum.KnowledgeRetrieval,
-    )
     const startVariables = startNode?.data.variables || []
     const fileSettings = featuresStore!.getState().features.file
     const {
@@ -72,31 +42,6 @@ export const useWorkflowStartRun = () => {
       setShowEnvPanel,
     } = workflowStore.getState()
 
-    if (knowledgeRetrievalNodes.length > 0) {
-      for (const node of knowledgeRetrievalNodes) {
-        if (isFromStartNode(node.id)) {
-          const res = checkKnowledgeRetrievalValid(node.data, t)
-          if (!res.isValid || !currentModel || !rerankDefaultModel) {
-            const errorMessage = res.errorMessage
-            if (errorMessage) {
-              Toast.notify({
-                type: 'error',
-                message: errorMessage,
-              })
-              return false
-            }
-            else {
-              Toast.notify({
-                type: 'error',
-                message: t('appDebug.datasetConfig.rerankModelRequired'),
-              })
-              return false
-            }
-          }
-        }
-      }
-    }
-
     setShowEnvPanel(false)
 
     if (showDebugAndPreviewPanel) {

+ 25 - 12
web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts

@@ -23,7 +23,7 @@ import type { DataSet } from '@/models/datasets'
 import { fetchDatasets } from '@/service/datasets'
 import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud'
 import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run'
-import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
+import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
 import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
 
 const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
@@ -34,6 +34,8 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
   const startNodeId = startNode?.id
   const { inputs, setInputs: doSetInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload)
 
+  const inputRef = useRef(inputs)
+
   const setInputs = useCallback((s: KnowledgeRetrievalNodeType) => {
     const newInputs = produce(s, (draft) => {
       if (s.retrieval_mode === RETRIEVE_TYPE.multiWay)
@@ -43,13 +45,9 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
     })
     // not work in pass to draft...
     doSetInputs(newInputs)
+    inputRef.current = newInputs
   }, [doSetInputs])
 
-  const inputRef = useRef(inputs)
-  useEffect(() => {
-    inputRef.current = inputs
-  }, [inputs])
-
   const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => {
     const newInputs = produce(inputs, (draft) => {
       draft.query_variable_selector = newVar as ValueSelector
@@ -63,9 +61,22 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
   } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration)
 
   const {
+    modelList: rerankModelList,
     defaultModel: rerankDefaultModel,
   } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
 
+  const {
+    currentModel: currentRerankModel,
+  } = useCurrentProviderAndModel(
+    rerankModelList,
+    rerankDefaultModel
+      ? {
+        ...rerankDefaultModel,
+        provider: rerankDefaultModel.provider.provider,
+      }
+      : undefined,
+  )
+
   const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => {
     const newInputs = produce(inputRef.current, (draft) => {
       if (!draft.single_retrieval_config) {
@@ -110,7 +121,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
   // set defaults models
   useEffect(() => {
     const inputs = inputRef.current
-    if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider)
+    if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider && currentRerankModel && rerankDefaultModel)
       return
 
     if (inputs.retrieval_mode === RETRIEVE_TYPE.oneWay && inputs.single_retrieval_config?.model?.provider)
@@ -130,7 +141,6 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
           }
         }
       }
-
       const multipleRetrievalConfig = draft.multiple_retrieval_config
       draft.multiple_retrieval_config = {
         top_k: multipleRetrievalConfig?.top_k || DATASET_DEFAULT.top_k,
@@ -138,6 +148,9 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
         reranking_model: multipleRetrievalConfig?.reranking_model,
         reranking_mode: multipleRetrievalConfig?.reranking_mode,
         weights: multipleRetrievalConfig?.weights,
+        reranking_enable: multipleRetrievalConfig?.reranking_enable !== undefined
+          ? multipleRetrievalConfig.reranking_enable
+          : Boolean(currentRerankModel && rerankDefaultModel),
       }
     })
     setInputs(newInput)
@@ -194,14 +207,14 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
   }, [])
 
   useEffect(() => {
+    const inputs = inputRef.current
     let query_variable_selector: ValueSelector = inputs.query_variable_selector
     if (isChatMode && inputs.query_variable_selector.length === 0 && startNodeId)
       query_variable_selector = [startNodeId, 'sys.query']
 
-    setInputs({
-      ...inputs,
-      query_variable_selector,
-    })
+    setInputs(produce(inputs, (draft) => {
+      draft.query_variable_selector = query_variable_selector
+    }))
   // eslint-disable-next-line react-hooks/exhaustive-deps
   }, [])
 

+ 1 - 1
web/app/components/workflow/nodes/knowledge-retrieval/utils.ts

@@ -113,7 +113,7 @@ export const getMultipleRetrievalConfig = (multipleRetrievalConfig: MultipleRetr
     reranking_mode,
     reranking_model,
     weights,
-    reranking_enable: allEconomic ? reranking_enable : true,
+    reranking_enable: ((allInternal && allEconomic) || allExternal) ? reranking_enable : true,
   }
 
   if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || allExternal || mixtureInternalAndExternal)