Просмотр исходного кода

Fix/retrieval setting weight default value (#9622)

zxhlyh 6 месяцев назад
Родитель
Коммит
ff956cb546

+ 18 - 1
web/app/components/app/configuration/dataset-config/index.tsx

@@ -13,6 +13,11 @@ import ContextVar from './context-var'
 import ConfigContext from '@/context/debug-configuration'
 import { AppType } from '@/types/app'
 import type { DataSet } from '@/models/datasets'
+import {
+  getMultipleRetrievalConfig,
+} from '@/app/components/workflow/nodes/knowledge-retrieval/utils'
+import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
+import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
 
 const Icon = (
   <svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
@@ -31,13 +36,25 @@ const DatasetConfig: FC = () => {
     setModelConfig,
     showSelectDataSet,
     isAgent,
+    datasetConfigs,
+    setDatasetConfigs,
   } = useContext(ConfigContext)
   const formattingChangedDispatcher = useFormattingChangedDispatcher()
 
   const hasData = dataSet.length > 0
 
+  const {
+    currentModel: currentRerankModel,
+  } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
+
   const onRemove = (id: string) => {
-    setDataSet(dataSet.filter(item => item.id !== id))
+    const filteredDataSets = dataSet.filter(item => item.id !== id)
+    setDataSet(filteredDataSets)
+    const retrievalConfig = getMultipleRetrievalConfig(datasetConfigs as any, filteredDataSets, dataSet, !!currentRerankModel)
+    setDatasetConfigs({
+      ...(datasetConfigs as any),
+      ...retrievalConfig,
+    })
     formattingChangedDispatcher()
   }
 

+ 1 - 1
web/app/components/app/configuration/dataset-config/params-config/config-content.tsx

@@ -55,7 +55,7 @@ const ConfigContent: FC<Props> = ({
         retrieval_model: RETRIEVE_TYPE.multiWay,
       }, isInWorkflow)
     }
-  }, [type])
+  }, [type, datasetConfigs, isInWorkflow, onChange])
 
   const {
     modelList: rerankModelList,

+ 4 - 53
web/app/components/app/configuration/dataset-config/params-config/index.tsx

@@ -16,7 +16,6 @@ import type { DataSet } from '@/models/datasets'
 import type { DatasetConfigs } from '@/models/debug'
 import {
   getMultipleRetrievalConfig,
-  getSelectedDatasetsMode,
 } from '@/app/components/workflow/nodes/knowledge-retrieval/utils'
 
 type ParamsConfigProps = {
@@ -37,57 +36,8 @@ const ParamsConfig = ({
   const [tempDataSetConfigs, setTempDataSetConfigs] = useState(datasetConfigs)
 
   useEffect(() => {
-    const {
-      allEconomic,
-      allHighQuality,
-      allHighQualityFullTextSearch,
-      allHighQualityVectorSearch,
-      allExternal,
-      mixtureHighQualityAndEconomic,
-      inconsistentEmbeddingModel,
-      mixtureInternalAndExternal,
-    } = getSelectedDatasetsMode(selectedDatasets)
-
-    if (allEconomic || allHighQuality || allHighQualityFullTextSearch || allHighQualityVectorSearch || (allExternal && selectedDatasets.length === 1))
-      setRerankSettingModalOpen(false)
-
-    if (mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || mixtureInternalAndExternal || (allExternal && selectedDatasets.length > 1))
-      setRerankSettingModalOpen(true)
-  }, [selectedDatasets])
-
-  useEffect(() => {
-    const {
-      allEconomic,
-      allInternal,
-      allExternal,
-    } = getSelectedDatasetsMode(selectedDatasets)
-    const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs
-    let rerankEnable = restConfigs.reranking_enable
-
-    if (((allInternal && allEconomic) || allExternal) && !restConfigs.reranking_model?.reranking_provider_name && rerankEnable === undefined)
-      rerankEnable = false
-
-    setTempDataSetConfigs({
-      ...getMultipleRetrievalConfig({
-        top_k: restConfigs.top_k,
-        score_threshold: restConfigs.score_threshold,
-        reranking_model: restConfigs.reranking_model && {
-          provider: restConfigs.reranking_model.reranking_provider_name,
-          model: restConfigs.reranking_model.reranking_model_name,
-        },
-        reranking_mode: restConfigs.reranking_mode,
-        weights: restConfigs.weights,
-        reranking_enable: rerankEnable,
-      }, selectedDatasets),
-      reranking_model: restConfigs.reranking_model && {
-        reranking_provider_name: restConfigs.reranking_model.reranking_provider_name,
-        reranking_model_name: restConfigs.reranking_model.reranking_model_name,
-      },
-      retrieval_model,
-      score_threshold_enabled,
-      datasets,
-    })
-  }, [selectedDatasets, datasetConfigs])
+    setTempDataSetConfigs(datasetConfigs)
+  }, [datasetConfigs])
 
   const {
     defaultModel: rerankDefaultModel,
@@ -135,7 +85,7 @@ const ParamsConfig = ({
       reranking_mode: restConfigs.reranking_mode,
       weights: restConfigs.weights,
       reranking_enable: restConfigs.reranking_enable,
-    }, selectedDatasets)
+    }, selectedDatasets, selectedDatasets, !!isRerankDefaultModelValid)
 
     setTempDataSetConfigs({
       ...retrievalConfig,
@@ -180,6 +130,7 @@ const ParamsConfig = ({
 
             <div className='mt-6 flex justify-end'>
               <Button className='mr-2 flex-shrink-0' onClick={() => {
+                setTempDataSetConfigs(datasetConfigs)
                 setRerankSettingModalOpen(false)
               }}>{t('common.operation.cancel')}</Button>
               <Button variant='primary' className='flex-shrink-0' onClick={handleSave} >{t('common.operation.save')}</Button>

+ 11 - 3
web/app/components/app/configuration/index.tsx

@@ -38,7 +38,7 @@ import ConfigContext from '@/context/debug-configuration'
 import Config from '@/app/components/app/configuration/config'
 import Debug from '@/app/components/app/configuration/debug'
 import Confirm from '@/app/components/base/confirm'
-import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
+import { ModelFeatureEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
 import { ToastContext } from '@/app/components/base/toast'
 import { fetchAppDetail, updateAppModelConfig } from '@/service/apps'
 import { promptVariablesToUserInputsForm, userInputsFormToPromptVariables } from '@/utils/model-config'
@@ -53,7 +53,10 @@ import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
 import Drawer from '@/app/components/base/drawer'
 import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal'
 import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations'
-import { useTextGenerationCurrentProviderAndModelAndModelList } from '@/app/components/header/account-setting/model-provider-page/hooks'
+import {
+  useModelListAndDefaultModelAndCurrentProviderAndModel,
+  useTextGenerationCurrentProviderAndModelAndModelList,
+} from '@/app/components/header/account-setting/model-provider-page/hooks'
 import { fetchCollectionList } from '@/service/tools'
 import { type Collection } from '@/app/components/tools/types'
 import { useStore as useAppStore } from '@/app/components/app/store'
@@ -217,6 +220,9 @@ const Configuration: FC = () => {
   const [isShowSelectDataSet, { setTrue: showSelectDataSet, setFalse: hideSelectDataSet }] = useBoolean(false)
   const selectedIds = dataSets.map(item => item.id)
   const [rerankSettingModalOpen, setRerankSettingModalOpen] = useState(false)
+  const {
+    currentModel: currentRerankModel,
+  } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
   const handleSelect = (data: DataSet[]) => {
     if (isEqual(data.map(item => item.id), dataSets.map(item => item.id))) {
       hideSelectDataSet()
@@ -263,7 +269,7 @@ const Configuration: FC = () => {
       reranking_mode: restConfigs.reranking_mode,
       weights: restConfigs.weights,
       reranking_enable: restConfigs.reranking_enable,
-    }, newDatasets)
+    }, newDatasets, dataSets, !!currentRerankModel)
 
     setDatasetConfigs({
       ...retrievalConfig,
@@ -603,9 +609,11 @@ const Configuration: FC = () => {
 
         syncToPublishedConfig(config)
         setPublishedConfig(config)
+        const retrievalConfig = getMultipleRetrievalConfig(modelConfig.dataset_configs, datasets, datasets, !!currentRerankModel)
         setDatasetConfigs({
           retrieval_model: RETRIEVE_TYPE.multiWay,
           ...modelConfig.dataset_configs,
+          ...retrievalConfig,
         })
         setHasFetchedDetail(true)
       })

+ 6 - 6
web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts

@@ -163,7 +163,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
       draft.retrieval_mode = newMode
       if (newMode === RETRIEVE_TYPE.multiWay) {
         const multipleRetrievalConfig = draft.multiple_retrieval_config
-        draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets)
+        draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets, selectedDatasets, !!currentRerankModel)
       }
       else {
         const hasSetModel = draft.single_retrieval_config?.model?.provider
@@ -180,14 +180,14 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
       }
     })
     setInputs(newInputs)
-  }, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets])
+  }, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets, currentRerankModel])
 
   const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => {
     const newInputs = produce(inputs, (draft) => {
-      draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets)
+      draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets, selectedDatasets, !!currentRerankModel)
     })
     setInputs(newInputs)
-  }, [inputs, setInputs, selectedDatasets])
+  }, [inputs, setInputs, selectedDatasets, currentRerankModel])
 
   // datasets
   useEffect(() => {
@@ -231,7 +231,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
 
       if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) {
         const multipleRetrievalConfig = draft.multiple_retrieval_config
-        draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets)
+        draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, !!currentRerankModel)
       }
     })
     setInputs(newInputs)
@@ -243,7 +243,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
       || (allExternal && newDatasets.length > 1)
     )
       setRerankModelOpen(true)
-  }, [inputs, setInputs, payload.retrieval_mode])
+  }, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel])
 
   const filterVar = useCallback((varPayload: Var) => {
     return varPayload.type === VarType.string

+ 46 - 3
web/app/components/workflow/nodes/knowledge-retrieval/utils.ts

@@ -1,4 +1,7 @@
-import { uniq } from 'lodash-es'
+import {
+  uniq,
+  xorBy,
+} from 'lodash-es'
 import type { MultipleRetrievalConfig } from './types'
 import type {
   DataSet,
@@ -15,7 +18,9 @@ export const checkNodeValid = () => {
   return true
 }
 
-export const getSelectedDatasetsMode = (datasets: DataSet[]) => {
+export const getSelectedDatasetsMode = (datasets: DataSet[] = []) => {
+  if (datasets === null)
+    datasets = []
   let allHighQuality = true
   let allHighQualityVectorSearch = true
   let allHighQualityFullTextSearch = true
@@ -85,7 +90,14 @@ export const getSelectedDatasetsMode = (datasets: DataSet[]) => {
   } as SelectedDatasetsMode
 }
 
-export const getMultipleRetrievalConfig = (multipleRetrievalConfig: MultipleRetrievalConfig, selectedDatasets: DataSet[]) => {
+export const getMultipleRetrievalConfig = (
+  multipleRetrievalConfig: MultipleRetrievalConfig,
+  selectedDatasets: DataSet[],
+  originalDatasets: DataSet[],
+  isValidRerankModel?: boolean,
+) => {
+  const shouldSetWeightDefaultValue = xorBy(selectedDatasets, originalDatasets, 'id').length > 0
+
   const {
     allHighQuality,
     allHighQualityVectorSearch,
@@ -123,6 +135,37 @@ export const getMultipleRetrievalConfig = (multipleRetrievalConfig: MultipleRetr
     result.reranking_mode = RerankingModeEnum.WeightedScore
 
   if (allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined) && allInternal && !weights) {
+    if (!isValidRerankModel)
+      result.reranking_mode = RerankingModeEnum.WeightedScore
+    else
+      result.reranking_mode = RerankingModeEnum.RerankingModel
+
+    result.weights = {
+      vector_setting: {
+        vector_weight: allHighQualityVectorSearch
+          ? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.semantic
+          : allHighQualityFullTextSearch
+            ? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.semantic
+            : DEFAULT_WEIGHTED_SCORE.other.semantic,
+        embedding_provider_name: selectedDatasets[0].embedding_model_provider,
+        embedding_model_name: selectedDatasets[0].embedding_model,
+      },
+      keyword_setting: {
+        keyword_weight: allHighQualityVectorSearch
+          ? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.keyword
+          : allHighQualityFullTextSearch
+            ? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.keyword
+            : DEFAULT_WEIGHTED_SCORE.other.keyword,
+      },
+    }
+  }
+
+  if (shouldSetWeightDefaultValue && allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined || !isValidRerankModel) && allInternal && weights) {
+    if (!isValidRerankModel)
+      result.reranking_mode = RerankingModeEnum.WeightedScore
+    else
+      result.reranking_mode = RerankingModeEnum.RerankingModel
+
     result.weights = {
       vector_setting: {
         vector_weight: allHighQualityVectorSearch

+ 0 - 8
web/models/datasets.ts

@@ -566,14 +566,6 @@ export const DEFAULT_WEIGHTED_SCORE = {
     semantic: 0,
     keyword: 1.0,
   },
-  semanticFirst: {
-    semantic: 0.7,
-    keyword: 0.3,
-  },
-  keywordFirst: {
-    semantic: 0.3,
-    keyword: 0.7,
-  },
   other: {
     semantic: 0.7,
     keyword: 0.3,