|
@@ -23,7 +23,7 @@ import type { DataSet } from '@/models/datasets'
|
|
|
import { fetchDatasets } from '@/service/datasets'
|
|
|
import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud'
|
|
|
import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run'
|
|
|
-import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
|
|
+import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
|
|
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
|
|
|
|
|
const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|
@@ -34,6 +34,8 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|
|
const startNodeId = startNode?.id
|
|
|
const { inputs, setInputs: doSetInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload)
|
|
|
|
|
|
+ const inputRef = useRef(inputs)
|
|
|
+
|
|
|
const setInputs = useCallback((s: KnowledgeRetrievalNodeType) => {
|
|
|
const newInputs = produce(s, (draft) => {
|
|
|
if (s.retrieval_mode === RETRIEVE_TYPE.multiWay)
|
|
@@ -43,13 +45,9 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|
|
})
|
|
|
// not work in pass to draft...
|
|
|
doSetInputs(newInputs)
|
|
|
+ inputRef.current = newInputs
|
|
|
}, [doSetInputs])
|
|
|
|
|
|
- const inputRef = useRef(inputs)
|
|
|
- useEffect(() => {
|
|
|
- inputRef.current = inputs
|
|
|
- }, [inputs])
|
|
|
-
|
|
|
const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => {
|
|
|
const newInputs = produce(inputs, (draft) => {
|
|
|
draft.query_variable_selector = newVar as ValueSelector
|
|
@@ -63,9 +61,22 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|
|
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration)
|
|
|
|
|
|
const {
|
|
|
+ modelList: rerankModelList,
|
|
|
defaultModel: rerankDefaultModel,
|
|
|
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
|
|
|
|
|
+ const {
|
|
|
+ currentModel: currentRerankModel,
|
|
|
+ } = useCurrentProviderAndModel(
|
|
|
+ rerankModelList,
|
|
|
+ rerankDefaultModel
|
|
|
+ ? {
|
|
|
+ ...rerankDefaultModel,
|
|
|
+ provider: rerankDefaultModel.provider.provider,
|
|
|
+ }
|
|
|
+ : undefined,
|
|
|
+ )
|
|
|
+
|
|
|
const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => {
|
|
|
const newInputs = produce(inputRef.current, (draft) => {
|
|
|
if (!draft.single_retrieval_config) {
|
|
@@ -110,7 +121,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|
|
// set defaults models
|
|
|
useEffect(() => {
|
|
|
const inputs = inputRef.current
|
|
|
- if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider)
|
|
|
+ if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider && currentRerankModel && rerankDefaultModel)
|
|
|
return
|
|
|
|
|
|
if (inputs.retrieval_mode === RETRIEVE_TYPE.oneWay && inputs.single_retrieval_config?.model?.provider)
|
|
@@ -130,7 +141,6 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
-
|
|
|
const multipleRetrievalConfig = draft.multiple_retrieval_config
|
|
|
draft.multiple_retrieval_config = {
|
|
|
top_k: multipleRetrievalConfig?.top_k || DATASET_DEFAULT.top_k,
|
|
@@ -138,6 +148,9 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|
|
reranking_model: multipleRetrievalConfig?.reranking_model,
|
|
|
reranking_mode: multipleRetrievalConfig?.reranking_mode,
|
|
|
weights: multipleRetrievalConfig?.weights,
|
|
|
+ reranking_enable: multipleRetrievalConfig?.reranking_enable !== undefined
|
|
|
+ ? multipleRetrievalConfig.reranking_enable
|
|
|
+ : Boolean(currentRerankModel && rerankDefaultModel),
|
|
|
}
|
|
|
})
|
|
|
setInputs(newInput)
|
|
@@ -194,14 +207,14 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|
|
}, [])
|
|
|
|
|
|
useEffect(() => {
|
|
|
+ const inputs = inputRef.current
|
|
|
let query_variable_selector: ValueSelector = inputs.query_variable_selector
|
|
|
if (isChatMode && inputs.query_variable_selector.length === 0 && startNodeId)
|
|
|
query_variable_selector = [startNodeId, 'sys.query']
|
|
|
|
|
|
- setInputs({
|
|
|
- ...inputs,
|
|
|
- query_variable_selector,
|
|
|
- })
|
|
|
+ setInputs(produce(inputs, (draft) => {
|
|
|
+ draft.query_variable_selector = query_variable_selector
|
|
|
+ }))
|
|
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
|
|
}, [])
|
|
|
|