|
@@ -12,6 +12,7 @@ import { RETRIEVE_TYPE } from '@/types/app'
|
|
|
import Toast from '@/app/components/base/toast'
|
|
|
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
|
|
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
|
|
+import { RerankingModeEnum } from '@/models/datasets'
|
|
|
import type { DataSet } from '@/models/datasets'
|
|
|
import type { DatasetConfigs } from '@/models/debug'
|
|
|
import {
|
|
@@ -47,7 +48,10 @@ const ParamsConfig = ({
|
|
|
const isValid = () => {
|
|
|
let errMsg = ''
|
|
|
if (tempDataSetConfigs.retrieval_model === RETRIEVE_TYPE.multiWay) {
|
|
|
- if (!tempDataSetConfigs.reranking_model?.reranking_model_name && (rerankDefaultModel && !isRerankDefaultModelValid))
|
|
|
+ if (tempDataSetConfigs.reranking_enable
|
|
|
+ && tempDataSetConfigs.reranking_mode === RerankingModeEnum.RerankingModel
|
|
|
+ && !isRerankDefaultModelValid
|
|
|
+ )
|
|
|
errMsg = t('appDebug.datasetConfig.rerankModelRequired')
|
|
|
}
|
|
|
if (errMsg) {
|
|
@@ -62,7 +66,9 @@ const ParamsConfig = ({
|
|
|
if (!isValid())
|
|
|
return
|
|
|
const config = { ...tempDataSetConfigs }
|
|
|
- if (config.retrieval_model === RETRIEVE_TYPE.multiWay && !config.reranking_model) {
|
|
|
+ if (config.retrieval_model === RETRIEVE_TYPE.multiWay
|
|
|
+ && config.reranking_mode === RerankingModeEnum.RerankingModel
|
|
|
+ && !config.reranking_model) {
|
|
|
config.reranking_model = {
|
|
|
reranking_provider_name: rerankDefaultModel?.provider?.provider,
|
|
|
reranking_model_name: rerankDefaultModel?.model,
|