فهرست منبع

feat: add datasets detail context and provider for improved data vali… (#16451)

Wu Tianwei 1 ماه پیش
والد
کامیت
9701b573e0

+ 53 - 0
web/app/components/workflow/datasets-detail-store/provider.tsx

@@ -0,0 +1,53 @@
+import type { FC } from 'react'
+import { createContext, useCallback, useEffect, useRef } from 'react'
+import { createDatasetsDetailStore } from './store'
+import type { CommonNodeType, Node } from '../types'
+import { BlockEnum } from '../types'
+import type { KnowledgeRetrievalNodeType } from '../nodes/knowledge-retrieval/types'
+import { fetchDatasets } from '@/service/datasets'
+
+type DatasetsDetailStoreApi = ReturnType<typeof createDatasetsDetailStore>
+
+type DatasetsDetailContextType = DatasetsDetailStoreApi | undefined
+
+export const DatasetsDetailContext = createContext<DatasetsDetailContextType>(undefined)
+
+type DatasetsDetailProviderProps = {
+  nodes: Node[]
+  children: React.ReactNode
+}
+
+const DatasetsDetailProvider: FC<DatasetsDetailProviderProps> = ({
+  nodes,
+  children,
+}) => {
+  const storeRef = useRef<DatasetsDetailStoreApi>()
+
+  if (!storeRef.current)
+    storeRef.current = createDatasetsDetailStore()
+
+  const updateDatasetsDetail = useCallback(async (datasetIds: string[]) => {
+    const { data: datasetsDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: datasetIds } })
+    if (datasetsDetail && datasetsDetail.length > 0)
+      storeRef.current!.getState().updateDatasetsDetail(datasetsDetail)
+  }, [])
+
+  useEffect(() => {
+    if (!storeRef.current) return
+    const knowledgeRetrievalNodes = nodes.filter(node => node.data.type === BlockEnum.KnowledgeRetrieval)
+    const allDatasetIds = knowledgeRetrievalNodes.reduce<string[]>((acc, node) => {
+      return Array.from(new Set([...acc, ...(node.data as CommonNodeType<KnowledgeRetrievalNodeType>).dataset_ids]))
+    }, [])
+    if (allDatasetIds.length === 0) return
+    updateDatasetsDetail(allDatasetIds)
+  // eslint-disable-next-line react-hooks/exhaustive-deps
+  }, [])
+
+  return (
+    <DatasetsDetailContext.Provider value={storeRef.current!}>
+      {children}
+    </DatasetsDetailContext.Provider>
+  )
+}
+
+export default DatasetsDetailProvider

+ 38 - 0
web/app/components/workflow/datasets-detail-store/store.ts

@@ -0,0 +1,38 @@
+import { useContext } from 'react'
+import { createStore, useStore } from 'zustand'
+import type { DataSet } from '@/models/datasets'
+import { DatasetsDetailContext } from './provider'
+import produce from 'immer'
+
+type DatasetsDetailStore = {
+  datasetsDetail: Record<string, DataSet>
+  updateDatasetsDetail: (datasetsDetail: DataSet[]) => void
+}
+
+export const createDatasetsDetailStore = () => {
+  return createStore<DatasetsDetailStore>((set, get) => ({
+    datasetsDetail: {},
+    updateDatasetsDetail: (datasets: DataSet[]) => {
+      const oldDatasetsDetail = get().datasetsDetail
+      const datasetsDetail = datasets.reduce<Record<string, DataSet>>((acc, dataset) => {
+        acc[dataset.id] = dataset
+        return acc
+      }, {})
+      // Merge new datasets detail into old one
+      const newDatasetsDetail = produce(oldDatasetsDetail, (draft) => {
+        Object.entries(datasetsDetail).forEach(([key, value]) => {
+          draft[key] = value
+        })
+      })
+      set({ datasetsDetail: newDatasetsDetail })
+    },
+  }))
+}
+
+export const useDatasetsDetailStore = <T>(selector: (state: DatasetsDetailStore) => T): T => {
+  const store = useContext(DatasetsDetailContext)
+  if (!store)
+    throw new Error('Missing DatasetsDetailContext.Provider in the tree')
+
+  return useStore(store, selector)
+}

+ 1 - 1
web/app/components/workflow/header/index.tsx

@@ -160,7 +160,7 @@ const Header: FC = () => {
   const { mutateAsync: publishWorkflow } = usePublishWorkflow(appID!)
 
   const onPublish = useCallback(async (params?: PublishWorkflowParams) => {
-    if (handleCheckBeforePublish()) {
+    if (await handleCheckBeforePublish()) {
       const res = await publishWorkflow({
         title: params?.title || '',
         releaseNotes: params?.releaseNotes || '',

+ 72 - 5
web/app/components/workflow/hooks/use-checklist.ts

@@ -1,10 +1,12 @@
 import {
   useCallback,
   useMemo,
+  useRef,
 } from 'react'
 import { useTranslation } from 'react-i18next'
 import { useStoreApi } from 'reactflow'
 import type {
+  CommonNodeType,
   Edge,
   Node,
 } from '../types'
@@ -27,6 +29,10 @@ import { useGetLanguage } from '@/context/i18n'
 import type { AgentNodeType } from '../nodes/agent/types'
 import { useStrategyProviders } from '@/service/use-strategy'
 import { canFindTool } from '@/utils'
+import { useDatasetsDetailStore } from '../datasets-detail-store/store'
+import type { KnowledgeRetrievalNodeType } from '../nodes/knowledge-retrieval/types'
+import type { DataSet } from '@/models/datasets'
+import { fetchDatasets } from '@/service/datasets'
 
 export const useChecklist = (nodes: Node[], edges: Edge[]) => {
   const { t } = useTranslation()
@@ -37,6 +43,24 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
   const customTools = useStore(s => s.customTools)
   const workflowTools = useStore(s => s.workflowTools)
   const { data: strategyProviders } = useStrategyProviders()
+  const datasetsDetail = useDatasetsDetailStore(s => s.datasetsDetail)
+
+  const getCheckData = useCallback((data: CommonNodeType<{}>) => {
+    let checkData = data
+    if (data.type === BlockEnum.KnowledgeRetrieval) {
+      const datasetIds = (data as CommonNodeType<KnowledgeRetrievalNodeType>).dataset_ids
+      const _datasets = datasetIds.reduce<DataSet[]>((acc, id) => {
+        if (datasetsDetail[id])
+          acc.push(datasetsDetail[id])
+        return acc
+      }, [])
+      checkData = {
+        ...data,
+        _datasets,
+      } as CommonNodeType<KnowledgeRetrievalNodeType>
+    }
+    return checkData
+  }, [datasetsDetail])
 
   const needWarningNodes = useMemo(() => {
     const list = []
@@ -75,7 +99,8 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
       }
 
       if (node.type === CUSTOM_NODE) {
-        const { errorMessage } = nodesExtraData[node.data.type].checkValid(node.data, t, moreDataForCheckValid)
+        const checkData = getCheckData(node.data)
+        const { errorMessage } = nodesExtraData[node.data.type].checkValid(checkData, t, moreDataForCheckValid)
 
         if (errorMessage || !validNodes.find(n => n.id === node.id)) {
           list.push({
@@ -109,7 +134,7 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
     }
 
     return list
-  }, [nodes, edges, isChatMode, buildInTools, customTools, workflowTools, language, nodesExtraData, t, strategyProviders])
+  }, [nodes, edges, isChatMode, buildInTools, customTools, workflowTools, language, nodesExtraData, t, strategyProviders, getCheckData])
 
   return needWarningNodes
 }
@@ -125,8 +150,31 @@ export const useChecklistBeforePublish = () => {
   const store = useStoreApi()
   const nodesExtraData = useNodesExtraData()
   const { data: strategyProviders } = useStrategyProviders()
+  const updateDatasetsDetail = useDatasetsDetailStore(s => s.updateDatasetsDetail)
+  const updateTime = useRef(0)
+
+  const getCheckData = useCallback((data: CommonNodeType<{}>, datasets: DataSet[]) => {
+    let checkData = data
+    if (data.type === BlockEnum.KnowledgeRetrieval) {
+      const datasetIds = (data as CommonNodeType<KnowledgeRetrievalNodeType>).dataset_ids
+      const datasetsDetail = datasets.reduce<Record<string, DataSet>>((acc, dataset) => {
+        acc[dataset.id] = dataset
+        return acc
+      }, {})
+      const _datasets = datasetIds.reduce<DataSet[]>((acc, id) => {
+        if (datasetsDetail[id])
+          acc.push(datasetsDetail[id])
+        return acc
+      }, [])
+      checkData = {
+        ...data,
+        _datasets,
+      } as CommonNodeType<KnowledgeRetrievalNodeType>
+    }
+    return checkData
+  }, [])
 
-  const handleCheckBeforePublish = useCallback(() => {
+  const handleCheckBeforePublish = useCallback(async () => {
     const {
       getNodes,
       edges,
@@ -141,6 +189,24 @@ export const useChecklistBeforePublish = () => {
       notify({ type: 'error', message: t('workflow.common.maxTreeDepth', { depth: MAX_TREE_DEPTH }) })
       return false
     }
+    // Before publish, we need to fetch datasets detail, in case of the settings of datasets have been changed
+    const knowledgeRetrievalNodes = nodes.filter(node => node.data.type === BlockEnum.KnowledgeRetrieval)
+    const allDatasetIds = knowledgeRetrievalNodes.reduce<string[]>((acc, node) => {
+      return Array.from(new Set([...acc, ...(node.data as CommonNodeType<KnowledgeRetrievalNodeType>).dataset_ids]))
+    }, [])
+    let datasets: DataSet[] = []
+    if (allDatasetIds.length > 0) {
+      updateTime.current = updateTime.current + 1
+      const currUpdateTime = updateTime.current
+      const { data: datasetsDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: allDatasetIds } })
+      if (datasetsDetail && datasetsDetail.length > 0) {
+        // avoid old data to overwrite the new data
+        if (currUpdateTime < updateTime.current)
+          return false
+        datasets = datasetsDetail
+        updateDatasetsDetail(datasetsDetail)
+      }
+    }
 
     for (let i = 0; i < nodes.length; i++) {
       const node = nodes[i]
@@ -161,7 +227,8 @@ export const useChecklistBeforePublish = () => {
         }
       }
 
-      const { errorMessage } = nodesExtraData[node.data.type as BlockEnum].checkValid(node.data, t, moreDataForCheckValid)
+      const checkData = getCheckData(node.data, datasets)
+      const { errorMessage } = nodesExtraData[node.data.type as BlockEnum].checkValid(checkData, t, moreDataForCheckValid)
 
       if (errorMessage) {
         notify({ type: 'error', message: `[${node.data.title}] ${errorMessage}` })
@@ -185,7 +252,7 @@ export const useChecklistBeforePublish = () => {
     }
 
     return true
-  }, [store, isChatMode, notify, t, buildInTools, customTools, workflowTools, language, nodesExtraData, strategyProviders])
+  }, [store, isChatMode, notify, t, buildInTools, customTools, workflowTools, language, nodesExtraData, strategyProviders, updateDatasetsDetail, getCheckData])
 
   return {
     handleCheckBeforePublish,

+ 8 - 5
web/app/components/workflow/index.tsx

@@ -99,6 +99,7 @@ import { useEventEmitterContextContext } from '@/context/event-emitter'
 import Confirm from '@/app/components/base/confirm'
 import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants'
 import { fetchFileUploadConfig } from '@/service/common'
+import DatasetsDetailProvider from './datasets-detail-store/provider'
 
 const nodeTypes = {
   [CUSTOM_NODE]: CustomNode,
@@ -448,11 +449,13 @@ const WorkflowWrap = memo(() => {
         nodes={nodesData}
         edges={edgesData} >
         <FeaturesProvider features={initialFeatures}>
-          <Workflow
-            nodes={nodesData}
-            edges={edgesData}
-            viewport={data?.graph.viewport}
-          />
+          <DatasetsDetailProvider nodes={nodesData}>
+            <Workflow
+              nodes={nodesData}
+              edges={edgesData}
+              viewport={data?.graph.viewport}
+            />
+          </DatasetsDetailProvider>
         </FeaturesProvider>
       </WorkflowHistoryProvider>
     </ReactFlowProvider>

+ 16 - 19
web/app/components/workflow/nodes/knowledge-retrieval/node.tsx

@@ -1,33 +1,30 @@
-import { type FC, useEffect, useRef, useState } from 'react'
+import { type FC, useEffect, useState } from 'react'
 import React from 'react'
 import type { KnowledgeRetrievalNodeType } from './types'
 import { Folder } from '@/app/components/base/icons/src/vender/solid/files'
 import type { NodeProps } from '@/app/components/workflow/types'
-import { fetchDatasets } from '@/service/datasets'
 import type { DataSet } from '@/models/datasets'
+import { useDatasetsDetailStore } from '../../datasets-detail-store/store'
 
 const Node: FC<NodeProps<KnowledgeRetrievalNodeType>> = ({
   data,
 }) => {
   const [selectedDatasets, setSelectedDatasets] = useState<DataSet[]>([])
-  const updateTime = useRef(0)
-  useEffect(() => {
-    (async () => {
-      updateTime.current = updateTime.current + 1
-      const currUpdateTime = updateTime.current
+  const datasetsDetail = useDatasetsDetailStore(s => s.datasetsDetail)
 
-      if (data.dataset_ids?.length > 0) {
-        const { data: dataSetsWithDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: data.dataset_ids } })
-        //  avoid old data overwrite new data
-        if (currUpdateTime < updateTime.current)
-          return
-        setSelectedDatasets(dataSetsWithDetail)
-      }
-      else {
-        setSelectedDatasets([])
-      }
-    })()
-  }, [data.dataset_ids])
+  useEffect(() => {
+    if (data.dataset_ids?.length > 0) {
+      const dataSetsWithDetail = data.dataset_ids.reduce<DataSet[]>((acc, id) => {
+        if (datasetsDetail[id])
+          acc.push(datasetsDetail[id])
+        return acc
+      }, [])
+      setSelectedDatasets(dataSetsWithDetail)
+    }
+    else {
+      setSelectedDatasets([])
+    }
+  }, [data.dataset_ids, datasetsDetail])
 
   if (!selectedDatasets.length)
     return null

+ 4 - 5
web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts

@@ -41,6 +41,7 @@ import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-s
 import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
 import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
 import useAvailableVarList from '@/app/components/workflow/nodes/_base/hooks/use-available-var-list'
+import { useDatasetsDetailStore } from '../../datasets-detail-store/store'
 
 const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
   const { nodesReadOnly: readOnly } = useNodesReadOnly()
@@ -49,6 +50,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
   const startNode = getBeforeNodesInSameBranch(id).find(node => node.data.type === BlockEnum.Start)
   const startNodeId = startNode?.id
   const { inputs, setInputs: doSetInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload)
+  const updateDatasetsDetail = useDatasetsDetailStore(s => s.updateDatasetsDetail)
 
   const inputRef = useRef(inputs)
 
@@ -218,15 +220,12 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
     (async () => {
       const inputs = inputRef.current
       const datasetIds = inputs.dataset_ids
-      let _datasets = selectedDatasets
       if (datasetIds?.length > 0) {
         const { data: dataSetsWithDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: datasetIds } as any })
-        _datasets = dataSetsWithDetail
         setSelectedDatasets(dataSetsWithDetail)
       }
       const newInputs = produce(inputs, (draft) => {
         draft.dataset_ids = datasetIds
-        draft._datasets = _datasets
       })
       setInputs(newInputs)
       setSelectedDatasetsLoaded(true)
@@ -256,7 +255,6 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
     } = getSelectedDatasetsMode(newDatasets)
     const newInputs = produce(inputs, (draft) => {
       draft.dataset_ids = newDatasets.map(d => d.id)
-      draft._datasets = newDatasets
 
       if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) {
         const multipleRetrievalConfig = draft.multiple_retrieval_config
@@ -266,6 +264,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
         })
       }
     })
+    updateDatasetsDetail(newDatasets)
     setInputs(newInputs)
     setSelectedDatasets(newDatasets)
 
@@ -275,7 +274,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
       || allExternal
     )
       setRerankModelOpen(true)
-  }, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel, currentRerankProvider])
+  }, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel, currentRerankProvider, updateDatasetsDetail])
 
   const filterVar = useCallback((varPayload: Var) => {
     return varPayload.type === VarType.string