Просмотр исходного кода

Feat: rerank model verification in front end (#9271)

Yi Xiao 6 месяцев назад
Родитель
Сommit
793205afc5

+ 37 - 12
web/app/components/app/configuration/dataset-config/params-config/config-content.tsx

@@ -1,6 +1,6 @@
 'use client'
 
-import { memo, useEffect, useMemo } from 'react'
+import { memo, useCallback, useEffect, useMemo } from 'react'
 import type { FC } from 'react'
 import { useTranslation } from 'react-i18next'
 import WeightedScore from './weighted-score'
@@ -11,7 +11,7 @@ import type {
   DatasetConfigs,
 } from '@/models/debug'
 import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
-import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
+import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
 import type { ModelConfig } from '@/app/components/workflow/types'
 import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal'
 import Tooltip from '@/app/components/base/tooltip'
@@ -23,6 +23,7 @@ import { RerankingModeEnum } from '@/models/datasets'
 import cn from '@/utils/classnames'
 import { useSelectedDatasetsMode } from '@/app/components/workflow/nodes/knowledge-retrieval/hooks'
 import Switch from '@/app/components/base/switch'
+import Toast from '@/app/components/base/toast'
 
 type Props = {
   datasetConfigs: DatasetConfigs
@@ -60,6 +61,24 @@ const ConfigContent: FC<Props> = ({
     modelList: rerankModelList,
     defaultModel: rerankDefaultModel,
   } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
+
+  const {
+    currentModel,
+  } = useCurrentProviderAndModel(
+    rerankModelList,
+    rerankDefaultModel
+      ? {
+        ...rerankDefaultModel,
+        provider: rerankDefaultModel.provider.provider,
+      }
+      : undefined,
+  )
+
+  const handleDisabledSwitchClick = useCallback(() => {
+    if (!currentModel)
+      Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
+  }, [currentModel, rerankDefaultModel, t])
+
   const rerankModel = (() => {
     if (datasetConfigs.reranking_model?.reranking_provider_name) {
       return {
@@ -231,16 +250,22 @@ const ConfigContent: FC<Props> = ({
                 <div className='flex items-center'>
                   {
                     selectedDatasetsMode.allEconomic && (
-                      <Switch
-                        size='md'
-                        defaultValue={showRerankModel}
-                        onChange={(v) => {
-                          onChange({
-                            ...datasetConfigs,
-                            reranking_enable: v,
-                          })
-                        }}
-                      />
+                      <div
+                        className='flex items-center'
+                        onClick={handleDisabledSwitchClick}
+                      >
+                        <Switch
+                          size='md'
+                          defaultValue={currentModel ? showRerankModel : false}
+                          disabled={!currentModel}
+                          onChange={(v) => {
+                            onChange({
+                              ...datasetConfigs,
+                              reranking_enable: v,
+                            })
+                          }}
+                        />
+                      </div>
                     )
                   }
                   <div className='leading-[32px] ml-1 text-text-secondary system-sm-semibold'>{t('common.modelProvider.rerankModel.key')}</div>

+ 37 - 12
web/app/components/datasets/common/retrieval-param-config/index.tsx

@@ -1,6 +1,6 @@
 'use client'
 import type { FC } from 'react'
-import React from 'react'
+import React, { useCallback } from 'react'
 import { useTranslation } from 'react-i18next'
 
 import cn from '@/utils/classnames'
@@ -11,7 +11,7 @@ import Switch from '@/app/components/base/switch'
 import Tooltip from '@/app/components/base/tooltip'
 import type { RetrievalConfig } from '@/types/app'
 import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
-import { useModelListAndDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
+import { useCurrentProviderAndModel, useModelListAndDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
 import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
 import {
   DEFAULT_WEIGHTED_SCORE,
@@ -19,6 +19,7 @@ import {
   WeightedScoreEnum,
 } from '@/models/datasets'
 import WeightedScore from '@/app/components/app/configuration/dataset-config/params-config/weighted-score'
+import Toast from '@/app/components/base/toast'
 
 type Props = {
   type: RETRIEVE_METHOD
@@ -38,6 +39,24 @@ const RetrievalParamConfig: FC<Props> = ({
     defaultModel: rerankDefaultModel,
     modelList: rerankModelList,
   } = useModelListAndDefaultModel(ModelTypeEnum.rerank)
+
+  const {
+    currentModel,
+  } = useCurrentProviderAndModel(
+    rerankModelList,
+    rerankDefaultModel
+      ? {
+        ...rerankDefaultModel,
+        provider: rerankDefaultModel.provider.provider,
+      }
+      : undefined,
+  )
+
+  const handleDisabledSwitchClick = useCallback(() => {
+    if (!currentModel)
+      Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
+  }, [currentModel, rerankDefaultModel, t])
+
   const isHybridSearch = type === RETRIEVE_METHOD.hybrid
 
   const rerankModel = (() => {
@@ -99,16 +118,22 @@ const RetrievalParamConfig: FC<Props> = ({
         <div>
           <div className='flex h-8 items-center text-[13px] font-medium text-gray-900 space-x-2'>
             {canToggleRerankModalEnable && (
-              <Switch
-                size='md'
-                defaultValue={value.reranking_enable}
-                onChange={(v) => {
-                  onChange({
-                    ...value,
-                    reranking_enable: v,
-                  })
-                }}
-              />
+              <div
+                className='flex items-center'
+                onClick={handleDisabledSwitchClick}
+              >
+                <Switch
+                  size='md'
+                  defaultValue={currentModel ? value.reranking_enable : false}
+                  onChange={(v) => {
+                    onChange({
+                      ...value,
+                      reranking_enable: v,
+                    })
+                  }}
+                  disabled={!currentModel}
+                />
+              </div>
             )}
             <div className='flex items-center'>
               <span className='mr-0.5'>{t('common.modelProvider.rerankModel.key')}</span>

+ 55 - 0
web/app/components/workflow/hooks/use-workflow-start-run.tsx

@@ -1,17 +1,25 @@
 import { useCallback } from 'react'
 import { useStoreApi } from 'reactflow'
+import { useTranslation } from 'react-i18next'
 import { useWorkflowStore } from '../store'
 import {
   BlockEnum,
   WorkflowRunningStatus,
 } from '../types'
+import type { KnowledgeRetrievalNodeType } from '../nodes/knowledge-retrieval/types'
+import type { Node } from '../types'
+import { useWorkflow } from './use-workflow'
 import {
   useIsChatMode,
   useNodesSyncDraft,
   useWorkflowInteractions,
   useWorkflowRun,
 } from './index'
+import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
+import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
 import { useFeaturesStore } from '@/app/components/base/features/hooks'
+import KnowledgeRetrievalDefault from '@/app/components/workflow/nodes/knowledge-retrieval/default'
+import Toast from '@/app/components/base/toast'
 
 export const useWorkflowStartRun = () => {
   const store = useStoreApi()
@@ -20,7 +28,26 @@ export const useWorkflowStartRun = () => {
   const isChatMode = useIsChatMode()
   const { handleCancelDebugAndPreviewPanel } = useWorkflowInteractions()
   const { handleRun } = useWorkflowRun()
+  const { isFromStartNode } = useWorkflow()
   const { doSyncWorkflowDraft } = useNodesSyncDraft()
+  const { checkValid: checkKnowledgeRetrievalValid } = KnowledgeRetrievalDefault
+  const { t } = useTranslation()
+  const {
+    modelList: rerankModelList,
+    defaultModel: rerankDefaultModel,
+  } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
+
+  const {
+    currentModel,
+  } = useCurrentProviderAndModel(
+    rerankModelList,
+    rerankDefaultModel
+      ? {
+        ...rerankDefaultModel,
+        provider: rerankDefaultModel.provider.provider,
+      }
+      : undefined,
+  )
 
   const handleWorkflowStartRunInWorkflow = useCallback(async () => {
     const {
@@ -33,6 +60,9 @@ export const useWorkflowStartRun = () => {
     const { getNodes } = store.getState()
     const nodes = getNodes()
     const startNode = nodes.find(node => node.data.type === BlockEnum.Start)
+    const knowledgeRetrievalNodes = nodes.filter((node: Node<KnowledgeRetrievalNodeType>) =>
+      node.data.type === BlockEnum.KnowledgeRetrieval,
+    )
     const startVariables = startNode?.data.variables || []
     const fileSettings = featuresStore!.getState().features.file
     const {
@@ -42,6 +72,31 @@ export const useWorkflowStartRun = () => {
       setShowEnvPanel,
     } = workflowStore.getState()
 
+    if (knowledgeRetrievalNodes.length > 0) {
+      for (const node of knowledgeRetrievalNodes) {
+        if (isFromStartNode(node.id)) {
+          const res = checkKnowledgeRetrievalValid(node.data, t)
+          if (!res.isValid || !currentModel || !rerankDefaultModel) {
+            const errorMessage = res.errorMessage
+            if (errorMessage) {
+              Toast.notify({
+                type: 'error',
+                message: errorMessage,
+              })
+              return false
+            }
+            else {
+              Toast.notify({
+                type: 'error',
+                message: t('appDebug.datasetConfig.rerankModelRequired'),
+              })
+              return false
+            }
+          }
+        }
+      }
+    }
+
     setShowEnvPanel(false)
 
     if (showDebugAndPreviewPanel) {

+ 28 - 0
web/app/components/workflow/hooks/use-workflow.ts

@@ -235,6 +235,33 @@ export const useWorkflow = () => {
     return nodes.filter(node => node.parentId === nodeId)
   }, [store])
 
+  const isFromStartNode = useCallback((nodeId: string) => {
+    const { getNodes } = store.getState()
+    const nodes = getNodes()
+    const currentNode = nodes.find(node => node.id === nodeId)
+
+    if (!currentNode)
+      return false
+
+    if (currentNode.data.type === BlockEnum.Start)
+      return true
+
+    const checkPreviousNodes = (node: Node) => {
+      const previousNodes = getBeforeNodeById(node.id)
+
+      for (const prevNode of previousNodes) {
+        if (prevNode.data.type === BlockEnum.Start)
+          return true
+        if (checkPreviousNodes(prevNode))
+          return true
+      }
+
+      return false
+    }
+
+    return checkPreviousNodes(currentNode)
+  }, [store, getBeforeNodeById])
+
   const handleOutVarRenameChange = useCallback((nodeId: string, oldValeSelector: ValueSelector, newVarSelector: ValueSelector) => {
     const { getNodes, setNodes } = store.getState()
     const afterNodes = getAfterNodesInSameBranch(nodeId)
@@ -389,6 +416,7 @@ export const useWorkflow = () => {
     checkParallelLimit,
     checkNestedParallelLimit,
     isValidConnection,
+    isFromStartNode,
     formatTimeFromNow,
     getNode,
     getBeforeNodeById,

+ 1 - 0
web/i18n/en-US/workflow.ts

@@ -172,6 +172,7 @@ const translation = {
   },
   errorMsg: {
     fieldRequired: '{{field}} is required',
+    rerankModelRequired: 'Before turning on the Rerank Model, please confirm that the model has been successfully configured in the settings.',
     authRequired: 'Authorization is required',
     invalidJson: '{{field}} is invalid JSON',
     fields: {

+ 1 - 0
web/i18n/zh-Hans/workflow.ts

@@ -172,6 +172,7 @@ const translation = {
   },
   errorMsg: {
     fieldRequired: '{{field}} 不能为空',
+    rerankModelRequired: '开启 Rerank 模型前,请务必确认模型已在设置中成功配置。',
     authRequired: '请先授权',
     invalidJson: '{{field}} 是非法的 JSON',
     fields: {