Переглянути джерело

fix: retrieval setting validate (#10454)

zxhlyh 5 місяців тому
батько
коміт
e4d175780e

+ 5 - 1
web/app/components/app/configuration/dataset-config/index.tsx

@@ -47,12 +47,16 @@ const DatasetConfig: FC = () => {
 
   const {
     currentModel: currentRerankModel,
+    currentProvider: currentRerankProvider,
   } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
 
   const onRemove = (id: string) => {
     const filteredDataSets = dataSet.filter(item => item.id !== id)
     setDataSet(filteredDataSets)
-    const retrievalConfig = getMultipleRetrievalConfig(datasetConfigs as any, filteredDataSets, dataSet, !!currentRerankModel)
+    const retrievalConfig = getMultipleRetrievalConfig(datasetConfigs as any, filteredDataSets, dataSet, {
+      provider: currentRerankProvider?.provider,
+      model: currentRerankModel?.model,
+    })
     setDatasetConfigs({
       ...(datasetConfigs as any),
       ...retrievalConfig,

+ 1 - 1
web/app/components/app/configuration/dataset-config/params-config/config-content.tsx

@@ -172,7 +172,7 @@ const ConfigContent: FC<Props> = ({
       return false
 
     return datasetConfigs.reranking_enable
-  }, [canManuallyToggleRerank, datasetConfigs.reranking_enable])
+  }, [canManuallyToggleRerank, datasetConfigs.reranking_enable, isRerankDefaultModelValid])
 
   const handleDisabledSwitchClick = useCallback(() => {
     if (!currentRerankModel && !showRerankModel)

+ 5 - 1
web/app/components/app/configuration/dataset-config/params-config/index.tsx

@@ -43,6 +43,7 @@ const ParamsConfig = ({
   const {
     defaultModel: rerankDefaultModel,
     currentModel: isRerankDefaultModelValid,
+    currentProvider: rerankDefaultProvider,
   } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
 
   const isValid = () => {
@@ -91,7 +92,10 @@ const ParamsConfig = ({
       reranking_mode: restConfigs.reranking_mode,
       weights: restConfigs.weights,
       reranking_enable: restConfigs.reranking_enable,
-    }, selectedDatasets, selectedDatasets, !!isRerankDefaultModelValid)
+    }, selectedDatasets, selectedDatasets, {
+      provider: rerankDefaultProvider?.provider,
+      model: isRerankDefaultModelValid?.model,
+    })
 
     setTempDataSetConfigs({
       ...retrievalConfig,

+ 9 - 2
web/app/components/app/configuration/index.tsx

@@ -226,6 +226,7 @@ const Configuration: FC = () => {
   const [rerankSettingModalOpen, setRerankSettingModalOpen] = useState(false)
   const {
     currentModel: currentRerankModel,
+    currentProvider: currentRerankProvider,
   } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
   const handleSelect = (data: DataSet[]) => {
     if (isEqual(data.map(item => item.id), dataSets.map(item => item.id))) {
@@ -279,7 +280,10 @@ const Configuration: FC = () => {
       reranking_mode: restConfigs.reranking_mode,
       weights: restConfigs.weights,
       reranking_enable: restConfigs.reranking_enable,
-    }, newDatasets, dataSets, !!currentRerankModel)
+    }, newDatasets, dataSets, {
+      provider: currentRerankProvider?.provider,
+      model: currentRerankModel?.model,
+    })
 
     setDatasetConfigs({
       ...retrievalConfig,
@@ -620,7 +624,10 @@ const Configuration: FC = () => {
 
         syncToPublishedConfig(config)
         setPublishedConfig(config)
-        const retrievalConfig = getMultipleRetrievalConfig(modelConfig.dataset_configs, datasets, datasets, !!currentRerankModel)
+        const retrievalConfig = getMultipleRetrievalConfig(modelConfig.dataset_configs, datasets, datasets, {
+          provider: currentRerankProvider?.provider,
+          model: currentRerankModel?.model,
+        })
         setDatasetConfigs({
           retrieval_model: RETRIEVE_TYPE.multiWay,
           ...modelConfig.dataset_configs,

+ 9 - 4
web/app/components/workflow/nodes/knowledge-retrieval/default.ts

@@ -1,7 +1,7 @@
 import { BlockEnum } from '../../types'
 import type { NodeDefault } from '../../types'
 import type { KnowledgeRetrievalNodeType } from './types'
-import { RerankingModeEnum } from '@/models/datasets'
+import { checkoutRerankModelConfigedInRetrievalSettings } from './utils'
 import { ALL_CHAT_AVAILABLE_BLOCKS, ALL_COMPLETION_AVAILABLE_BLOCKS } from '@/app/components/workflow/constants'
 import { DATASET_DEFAULT } from '@/config'
 import { RETRIEVE_TYPE } from '@/types/app'
@@ -36,12 +36,17 @@ const nodeDefault: NodeDefault<KnowledgeRetrievalNodeType> = {
     if (!errorMessages && (!payload.dataset_ids || payload.dataset_ids.length === 0))
       errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.nodes.knowledgeRetrieval.knowledge`) })
 
-    if (!errorMessages && payload.retrieval_mode === RETRIEVE_TYPE.multiWay && payload.multiple_retrieval_config?.reranking_mode === RerankingModeEnum.RerankingModel && !payload.multiple_retrieval_config?.reranking_model?.provider && payload.multiple_retrieval_config?.reranking_enable)
-      errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.errorMsg.fields.rerankModel`) })
-
     if (!errorMessages && payload.retrieval_mode === RETRIEVE_TYPE.oneWay && !payload.single_retrieval_config?.model?.provider)
       errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t('common.modelProvider.systemReasoningModel.key') })
 
+    const { _datasets, multiple_retrieval_config, retrieval_mode } = payload
+    if (retrieval_mode === RETRIEVE_TYPE.multiWay) {
+      const checked = checkoutRerankModelConfigedInRetrievalSettings(_datasets || [], multiple_retrieval_config)
+
+      if (!errorMessages && !checked)
+        errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.errorMsg.fields.rerankModel`) })
+    }
+
     return {
       isValid: !errorMessages,
       errorMessage: errorMessages,

+ 2 - 0
web/app/components/workflow/nodes/knowledge-retrieval/types.ts

@@ -1,6 +1,7 @@
 import type { CommonNodeType, ModelConfig, ValueSelector } from '@/app/components/workflow/types'
 import type { RETRIEVE_TYPE } from '@/types/app'
 import type {
+  DataSet,
   RerankingModeEnum,
 } from '@/models/datasets'
 
@@ -35,4 +36,5 @@ export type KnowledgeRetrievalNodeType = CommonNodeType & {
   retrieval_mode: RETRIEVE_TYPE
   multiple_retrieval_config?: MultipleRetrievalConfig
   single_retrieval_config?: SingleRetrievalConfig
+  _datasets?: DataSet[]
 }

+ 18 - 6
web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts

@@ -67,6 +67,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
 
   const {
     currentModel: currentRerankModel,
+    currentProvider: currentRerankProvider,
   } = useCurrentProviderAndModel(
     rerankModelList,
     rerankDefaultModel
@@ -163,7 +164,10 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
       draft.retrieval_mode = newMode
       if (newMode === RETRIEVE_TYPE.multiWay) {
         const multipleRetrievalConfig = draft.multiple_retrieval_config
-        draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets, selectedDatasets, !!currentRerankModel)
+        draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets, selectedDatasets, {
+          provider: currentRerankProvider?.provider,
+          model: currentRerankModel?.model,
+        })
       }
       else {
         const hasSetModel = draft.single_retrieval_config?.model?.provider
@@ -180,14 +184,17 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
       }
     })
     setInputs(newInputs)
-  }, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets, currentRerankModel])
+  }, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets, currentRerankModel, currentRerankProvider])
 
   const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => {
     const newInputs = produce(inputs, (draft) => {
-      draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets, selectedDatasets, !!currentRerankModel)
+      draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets, selectedDatasets, {
+        provider: currentRerankProvider?.provider,
+        model: currentRerankModel?.model,
+      })
     })
     setInputs(newInputs)
-  }, [inputs, setInputs, selectedDatasets, currentRerankModel])
+  }, [inputs, setInputs, selectedDatasets, currentRerankModel, currentRerankProvider])
 
   // datasets
   useEffect(() => {
@@ -200,6 +207,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
       }
       const newInputs = produce(inputs, (draft) => {
         draft.dataset_ids = datasetIds
+        draft._datasets = selectedDatasets
       })
       setInputs(newInputs)
     })()
@@ -228,10 +236,14 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
     } = getSelectedDatasetsMode(newDatasets)
     const newInputs = produce(inputs, (draft) => {
       draft.dataset_ids = newDatasets.map(d => d.id)
+      draft._datasets = newDatasets
 
       if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) {
         const multipleRetrievalConfig = draft.multiple_retrieval_config
-        draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, !!currentRerankModel)
+        draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, {
+          provider: currentRerankProvider?.provider,
+          model: currentRerankModel?.model,
+        })
       }
     })
     setInputs(newInputs)
@@ -243,7 +255,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
       || allExternal
     )
       setRerankModelOpen(true)
-  }, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel])
+  }, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel, currentRerankProvider])
 
   const filterVar = useCallback((varPayload: Var) => {
     return varPayload.type === VarType.string

+ 80 - 33
web/app/components/workflow/nodes/knowledge-retrieval/utils.ts

@@ -94,9 +94,10 @@ export const getMultipleRetrievalConfig = (
   multipleRetrievalConfig: MultipleRetrievalConfig,
   selectedDatasets: DataSet[],
   originalDatasets: DataSet[],
-  isValidRerankModel?: boolean,
+  validRerankModel?: { provider?: string; model?: string },
 ) => {
   const shouldSetWeightDefaultValue = xorBy(selectedDatasets, originalDatasets, 'id').length > 0
+  const rerankModelIsValid = validRerankModel?.provider && validRerankModel?.model
 
   const {
     allHighQuality,
@@ -128,18 +129,10 @@ export const getMultipleRetrievalConfig = (
     reranking_enable: ((allInternal && allEconomic) || allExternal) ? reranking_enable : true,
   }
 
-  if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || allExternal || mixtureInternalAndExternal)
-    result.reranking_mode = RerankingModeEnum.RerankingModel
-
-  if (allHighQuality && !inconsistentEmbeddingModel && reranking_mode === undefined && allInternal)
-    result.reranking_mode = RerankingModeEnum.WeightedScore
-
-  if (allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined) && allInternal && !weights) {
-    if (!isValidRerankModel)
-      result.reranking_mode = RerankingModeEnum.WeightedScore
-    else
-      result.reranking_mode = RerankingModeEnum.RerankingModel
+  if (!rerankModelIsValid)
+    result.reranking_model = undefined
 
+  const setDefaultWeights = () => {
     result.weights = {
       vector_setting: {
         vector_weight: allHighQualityVectorSearch
@@ -160,31 +153,85 @@ export const getMultipleRetrievalConfig = (
     }
   }
 
-  if (shouldSetWeightDefaultValue && allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined || !isValidRerankModel) && allInternal && weights) {
-    if (!isValidRerankModel)
-      result.reranking_mode = RerankingModeEnum.WeightedScore
-    else
+  if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || allExternal || mixtureInternalAndExternal) {
+    result.reranking_mode = RerankingModeEnum.RerankingModel
+
+    if (rerankModelIsValid) {
       result.reranking_mode = RerankingModeEnum.RerankingModel
+      result.reranking_model = {
+        provider: validRerankModel?.provider || '',
+        model: validRerankModel?.model || '',
+      }
+    }
+    else {
+      result.reranking_model = undefined
+    }
+  }
 
-    result.weights = {
-      vector_setting: {
-        vector_weight: allHighQualityVectorSearch
-          ? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.semantic
-          : allHighQualityFullTextSearch
-            ? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.semantic
-            : DEFAULT_WEIGHTED_SCORE.other.semantic,
-        embedding_provider_name: selectedDatasets[0].embedding_model_provider,
-        embedding_model_name: selectedDatasets[0].embedding_model,
-      },
-      keyword_setting: {
-        keyword_weight: allHighQualityVectorSearch
-          ? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.keyword
-          : allHighQualityFullTextSearch
-            ? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.keyword
-            : DEFAULT_WEIGHTED_SCORE.other.keyword,
-      },
+  if (allHighQuality && !inconsistentEmbeddingModel && allInternal) {
+    if (!reranking_mode) {
+      if (validRerankModel?.provider && validRerankModel?.model) {
+        result.reranking_mode = RerankingModeEnum.RerankingModel
+        result.reranking_model = {
+          provider: validRerankModel.provider,
+          model: validRerankModel.model,
+        }
+      }
+      else {
+        result.reranking_mode = RerankingModeEnum.WeightedScore
+        setDefaultWeights()
+      }
+    }
+
+    if (reranking_mode === RerankingModeEnum.WeightedScore && !weights)
+      setDefaultWeights()
+
+    if (reranking_mode === RerankingModeEnum.WeightedScore && weights && shouldSetWeightDefaultValue) {
+      if (rerankModelIsValid) {
+        result.reranking_mode = RerankingModeEnum.RerankingModel
+        result.reranking_model = {
+          provider: validRerankModel.provider || '',
+          model: validRerankModel.model || '',
+        }
+      }
+      else {
+        setDefaultWeights()
+      }
+    }
+
+    if (reranking_mode === RerankingModeEnum.RerankingModel && !rerankModelIsValid && shouldSetWeightDefaultValue) {
+      result.reranking_mode = RerankingModeEnum.WeightedScore
+      setDefaultWeights()
     }
   }
 
   return result
 }
+/**
+ * Check whether the multi-way retrieval settings carry the rerank model they need.
+ *
+ * @param datasets - Selected datasets; only the `allEconomic` and `allExternal`
+ *   flags derived by `getSelectedDatasetsMode` are consulted here.
+ * @param multipleRetrievalConfig - Retrieval settings to inspect; when it is
+ *   missing the check passes.
+ * @returns `false` only when `reranking_mode` is `RerankingModel`, the rerank
+ *   model lacks a provider or a model name, and the exemption for all-economic
+ *   or all-external selections with reranking disabled does not apply;
+ *   `true` in every other case.
+ */
+export const checkoutRerankModelConfigedInRetrievalSettings = (
+  datasets: DataSet[],
+  multipleRetrievalConfig?: MultipleRetrievalConfig,
+) => {
+  if (!multipleRetrievalConfig) // no settings at all: nothing to validate
+    return true
+
+  const {
+    allEconomic,
+    allExternal,
+  } = getSelectedDatasetsMode(datasets)
+
+  const {
+    reranking_enable,
+    reranking_mode,
+    reranking_model,
+  } = multipleRetrievalConfig
+
+  if (reranking_mode === RerankingModeEnum.RerankingModel && (!reranking_model?.provider || !reranking_model?.model)) { // rerank-model mode chosen, but provider or model is missing
+    if ((allEconomic || allExternal) && !reranking_enable) // reranking is switched off for an all-economic or all-external selection, so the missing model is accepted
+      return true
+
+    return false // any other selection with an incomplete rerank model fails the check
+  }
+
+  return true
+}