Sfoglia il codice sorgente

Improve ModelTypeEnum type (#3051)

Nanguan Lin 1 anno fa
parent
commit
718ac3f83b

+ 3 - 2
web/app/components/app/configuration/config/index.tsx

@@ -26,6 +26,7 @@ import { useModalContext } from '@/context/modal-context'
 import ConfigParamModal from '@/app/components/app/configuration/toolbox/annotation/config-param-modal'
 import AnnotationFullModal from '@/app/components/billing/annotation-full/modal'
 import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
+import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
 
 const Config: FC = () => {
   const {
@@ -61,8 +62,8 @@ const Config: FC = () => {
     setModerationConfig,
   } = useContext(ConfigContext)
   const isChatApp = mode === AppType.chat
-  const { data: speech2textDefaultModel } = useDefaultModel(4)
-  const { data: text2speechDefaultModel } = useDefaultModel(5)
+  const { data: speech2textDefaultModel } = useDefaultModel(ModelTypeEnum.speech2text)
+  const { data: text2speechDefaultModel } = useDefaultModel(ModelTypeEnum.tts)
   const { setShowModerationSettingModal } = useModalContext()
   const formattingChangedDispatcher = useFormattingChangedDispatcher()
 

+ 2 - 1
web/app/components/app/configuration/dataset-config/params-config/index.tsx

@@ -20,6 +20,7 @@ import {
 } from '@/app/components/base/icons/src/public/common'
 import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
 import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
+import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
 
 const ParamsConfig: FC = () => {
   const { t } = useTranslation()
@@ -41,7 +42,7 @@ const ParamsConfig: FC = () => {
     modelList: rerankModelList,
     defaultModel: rerankDefaultModel,
     currentModel: isRerankDefaultModelVaild,
-  } = useModelListAndDefaultModelAndCurrentProviderAndModel(3)
+  } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
 
   const rerankModel = (() => {
     if (tempDataSetConfigs.reranking_model) {

+ 3 - 2
web/app/components/app/configuration/dataset-config/settings-modal/index.tsx

@@ -22,6 +22,7 @@ import {
   useModelList,
   useModelListAndDefaultModelAndCurrentProviderAndModel,
 } from '@/app/components/header/account-setting/model-provider-page/hooks'
+import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
 
 type SettingsModalProps = {
   currentDataset: DataSet
@@ -42,12 +43,12 @@ const SettingsModal: FC<SettingsModalProps> = ({
   onCancel,
   onSave,
 }) => {
-  const { data: embeddingsModelList } = useModelList(2)
+  const { data: embeddingsModelList } = useModelList(ModelTypeEnum.textEmbedding)
   const {
     modelList: rerankModelList,
     defaultModel: rerankDefaultModel,
     currentModel: isRerankDefaultModelVaild,
-  } = useModelListAndDefaultModelAndCurrentProviderAndModel(3)
+  } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
   const { t } = useTranslation()
   const { notify } = useToastContext()
   const ref = useRef(null)

+ 2 - 2
web/app/components/app/configuration/debug/index.tsx

@@ -31,7 +31,7 @@ import { IS_CE_EDITION } from '@/config'
 import type { Inputs } from '@/models/debug'
 import { fetchFileUploadConfig } from '@/service/common'
 import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
-import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
+import { ModelFeatureEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
 import type { ModelParameterModalProps } from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal'
 import { Plus } from '@/app/components/base/icons/src/vender/line/general'
 import { useEventEmitterContextContext } from '@/context/event-emitter'
@@ -84,7 +84,7 @@ const Debug: FC<IDebug> = ({
     setVisionConfig,
   } = useContext(ConfigContext)
   const { eventEmitter } = useEventEmitterContextContext()
-  const { data: text2speechDefaultModel } = useDefaultModel(5)
+  const { data: text2speechDefaultModel } = useDefaultModel(ModelTypeEnum.textEmbedding)
   const { data: fileUploadConfigResponse } = useSWR({ url: '/files/upload' }, fetchFileUploadConfig)
   useEffect(() => {
     setAutoFreeze(false)

+ 2 - 1
web/app/components/app/configuration/toolbox/annotation/config-param-modal.tsx

@@ -11,6 +11,7 @@ import type { AnnotationReplyConfig } from '@/models/debug'
 import { ANNOTATION_DEFAULT } from '@/config'
 import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
 import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
+import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
 
 type Props = {
   appId: string
@@ -36,7 +37,7 @@ const ConfigParamModal: FC<Props> = ({
     modelList: embeddingsModelList,
     defaultModel: embeddingsDefaultModel,
     currentModel: isEmbeddingsDefaultModelValid,
-  } = useModelListAndDefaultModelAndCurrentProviderAndModel(2)
+  } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textEmbedding)
   const [annotationConfig, setAnnotationConfig] = useState(oldAnnotationConfig)
 
   const [isLoading, setLoading] = useState(false)

+ 2 - 1
web/app/components/datasets/common/retrieval-method-config/index.tsx

@@ -10,6 +10,7 @@ import { PatternRecognition, Semantic } from '@/app/components/base/icons/src/ve
 import { FileSearch02 } from '@/app/components/base/icons/src/vender/solid/files'
 import { useProviderContext } from '@/context/provider-context'
 import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
+import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
 
 type Props = {
   value: RetrievalConfig
@@ -22,7 +23,7 @@ const RetrievalMethodConfig: FC<Props> = ({
 }) => {
   const { t } = useTranslation()
   const { supportRetrievalMethods } = useProviderContext()
-  const { data: rerankDefaultModel } = useDefaultModel(3)
+  const { data: rerankDefaultModel } = useDefaultModel(ModelTypeEnum.rerank)
   const value = (() => {
     if (!passValue.reranking_model.reranking_model_name) {
       return {

+ 2 - 1
web/app/components/datasets/common/retrieval-param-config/index.tsx

@@ -12,6 +12,7 @@ import { HelpCircle } from '@/app/components/base/icons/src/vender/line/general'
 import type { RetrievalConfig } from '@/types/app'
 import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
 import { useModelListAndDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
+import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
 
 type Props = {
   type: RETRIEVE_METHOD
@@ -30,7 +31,7 @@ const RetrievalParamConfig: FC<Props> = ({
   const {
     defaultModel: rerankDefaultModel,
     modelList: rerankModelList,
-  } = useModelListAndDefaultModel(3)
+  } = useModelListAndDefaultModel(ModelTypeEnum.rerank)
 
   const rerankModel = (() => {
     if (value.reranking_model) {

+ 2 - 1
web/app/components/datasets/create/index.tsx

@@ -2,6 +2,7 @@
 import React, { useCallback, useEffect, useState } from 'react'
 import { useTranslation } from 'react-i18next'
 import AppUnavailable from '../../base/app-unavailable'
+import { ModelTypeEnum } from '../../header/account-setting/model-provider-page/declarations'
 import StepsNavBar from './steps-nav-bar'
 import StepOne from './step-one'
 import StepTwo from './step-two'
@@ -28,7 +29,7 @@ const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => {
   const [fileList, setFiles] = useState<FileItem[]>([])
   const [result, setResult] = useState<createDocumentResponse | undefined>()
   const [hasError, setHasError] = useState(false)
-  const { data: embeddingsDefaultModel } = useDefaultModel(2)
+  const { data: embeddingsDefaultModel } = useDefaultModel(ModelTypeEnum.textEmbedding)
 
   const [notionPages, setNotionPages] = useState<NotionPage[]>([])
   const updateNotionPages = (value: NotionPage[]) => {

+ 2 - 1
web/app/components/datasets/create/step-two/index.tsx

@@ -43,6 +43,7 @@ import Tooltip from '@/app/components/base/tooltip'
 import TooltipPlus from '@/app/components/base/tooltip-plus'
 import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
 import { LanguagesSupported } from '@/i18n/language'
+import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
 
 type ValueOf<T> = T[keyof T]
 type StepTwoProps = {
@@ -275,7 +276,7 @@ const StepTwo = ({
     modelList: rerankModelList,
     defaultModel: rerankDefaultModel,
     currentModel: isRerankDefaultModelVaild,
-  } = useModelListAndDefaultModelAndCurrentProviderAndModel(3)
+  } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
   const getCreationParams = () => {
     let params
     if (segmentationType === SegmentType.CUSTOM && overlap > max) {

+ 2 - 1
web/app/components/datasets/documents/detail/settings/index.tsx

@@ -14,6 +14,7 @@ import StepTwo from '@/app/components/datasets/create/step-two'
 import AccountSetting from '@/app/components/header/account-setting'
 import AppUnavailable from '@/app/components/base/app-unavailable'
 import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
+import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
 
 type DocumentSettingsProps = {
   datasetId: string
@@ -26,7 +27,7 @@ const DocumentSettings = ({ datasetId, documentId }: DocumentSettingsProps) => {
   const [isShowSetAPIKey, { setTrue: showSetAPIKey, setFalse: hideSetAPIkey }] = useBoolean()
   const [hasError, setHasError] = useState(false)
   const { indexingTechnique, dataset } = useContext(DatasetDetailContext)
-  const { data: embeddingsDefaultModel } = useDefaultModel(2)
+  const { data: embeddingsDefaultModel } = useDefaultModel(ModelTypeEnum.textEmbedding)
 
   const saveHandler = () => router.push(`/datasets/${datasetId}/documents/${documentId}`)
 

+ 2 - 1
web/app/components/datasets/hit-testing/modify-retrieval-modal.tsx

@@ -3,6 +3,7 @@ import type { FC } from 'react'
 import React, { useRef, useState } from 'react'
 import { useTranslation } from 'react-i18next'
 import Toast from '../../base/toast'
+import { ModelTypeEnum } from '../../header/account-setting/model-provider-page/declarations'
 import { XClose } from '@/app/components/base/icons/src/vender/line/general'
 import type { RetrievalConfig } from '@/types/app'
 import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config'
@@ -39,7 +40,7 @@ const ModifyRetrievalModal: FC<Props> = ({
     modelList: rerankModelList,
     defaultModel: rerankDefaultModel,
     currentModel: isRerankDefaultModelVaild,
-  } = useModelListAndDefaultModelAndCurrentProviderAndModel(3)
+  } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
 
   const handleSave = () => {
     if (

+ 3 - 2
web/app/components/datasets/settings/form/index.tsx

@@ -24,6 +24,7 @@ import {
   useModelList,
   useModelListAndDefaultModelAndCurrentProviderAndModel,
 } from '@/app/components/header/account-setting/model-provider-page/hooks'
+import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
 
 const rowClass = `
   flex justify-between py-4 flex-wrap gap-y-2
@@ -63,8 +64,8 @@ const Form = () => {
     modelList: rerankModelList,
     defaultModel: rerankDefaultModel,
     currentModel: isRerankDefaultModelVaild,
-  } = useModelListAndDefaultModelAndCurrentProviderAndModel(3)
-  const { data: embeddingModelList } = useModelList(2)
+  } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
+  const { data: embeddingModelList } = useModelList(ModelTypeEnum.textEmbedding)
 
   const handleSave = async () => {
     if (loading)

+ 10 - 19
web/app/components/header/account-setting/model-provider-page/hooks.ts

@@ -11,10 +11,11 @@ import type {
   DefaultModel,
   DefaultModelResponse,
   Model,
+
+  ModelTypeEnum,
 } from './declarations'
 import {
   ConfigurateMethodEnum,
-  ModelTypeEnum,
 } from './declarations'
 import I18n from '@/context/i18n'
 import {
@@ -99,17 +100,8 @@ export const useProviderCrenditialsFormSchemasValue = (
   return value
 }
 
-export type ModelTypeIndex = 1 | 2 | 3 | 4 | 5
-export const MODEL_TYPE_MAPS = {
-  1: ModelTypeEnum.textGeneration,
-  2: ModelTypeEnum.textEmbedding,
-  3: ModelTypeEnum.rerank,
-  4: ModelTypeEnum.speech2text,
-  5: ModelTypeEnum.tts,
-}
-
-export const useModelList = (type: ModelTypeIndex) => {
-  const { data, mutate, isLoading } = useSWR(`/workspaces/current/models/model-types/${MODEL_TYPE_MAPS[type]}`, fetchModelList)
+export const useModelList = (type: ModelTypeEnum) => {
+  const { data, mutate, isLoading } = useSWR(`/workspaces/current/models/model-types/${type}`, fetchModelList)
 
   return {
     data: data?.data || [],
@@ -118,8 +110,8 @@ export const useModelList = (type: ModelTypeIndex) => {
   }
 }
 
-export const useDefaultModel = (type: ModelTypeIndex) => {
-  const { data, mutate, isLoading } = useSWR(`/workspaces/current/default-model?model_type=${MODEL_TYPE_MAPS[type]}`, fetchDefaultModal)
+export const useDefaultModel = (type: ModelTypeEnum) => {
+  const { data, mutate, isLoading } = useSWR(`/workspaces/current/default-model?model_type=${type}`, fetchDefaultModal)
 
   return {
     data: data?.data,
@@ -152,7 +144,7 @@ export const useTextGenerationCurrentProviderAndModelAndModelList = (defaultMode
   }
 }
 
-export const useModelListAndDefaultModel = (type: ModelTypeIndex) => {
+export const useModelListAndDefaultModel = (type: ModelTypeEnum) => {
   const { data: modelList } = useModelList(type)
   const { data: defaultModel } = useDefaultModel(type)
 
@@ -162,7 +154,7 @@ export const useModelListAndDefaultModel = (type: ModelTypeIndex) => {
   }
 }
 
-export const useModelListAndDefaultModelAndCurrentProviderAndModel = (type: ModelTypeIndex) => {
+export const useModelListAndDefaultModelAndCurrentProviderAndModel = (type: ModelTypeEnum) => {
   const { modelList, defaultModel } = useModelListAndDefaultModel(type)
   const { currentProvider, currentModel } = useCurrentProviderAndModel(
     modelList,
@@ -180,9 +172,8 @@ export const useModelListAndDefaultModelAndCurrentProviderAndModel = (type: Mode
 export const useUpdateModelList = () => {
   const { mutate } = useSWRConfig()
 
-  const updateModelList = useCallback((type: ModelTypeIndex | ModelTypeEnum) => {
-    const modelType = typeof type === 'number' ? MODEL_TYPE_MAPS[type] : type
-    mutate(`/workspaces/current/models/model-types/${modelType}`)
+  const updateModelList = useCallback((type: ModelTypeEnum) => {
+    mutate(`/workspaces/current/models/model-types/${type}`)
   }, [mutate])
 
   return updateModelList

+ 6 - 5
web/app/components/header/account-setting/model-provider-page/index.tsx

@@ -10,6 +10,7 @@ import type {
 import {
   ConfigurateMethodEnum,
   CustomConfigurationStatusEnum,
+  ModelTypeEnum,
 } from './declarations'
 import {
   useDefaultModel,
@@ -26,11 +27,11 @@ const ModelProviderPage = () => {
   const { eventEmitter } = useEventEmitterContextContext()
   const updateModelProviders = useUpdateModelProviders()
   const updateModelList = useUpdateModelList()
-  const { data: textGenerationDefaultModel } = useDefaultModel(1)
-  const { data: embeddingsDefaultModel } = useDefaultModel(2)
-  const { data: rerankDefaultModel } = useDefaultModel(3)
-  const { data: speech2textDefaultModel } = useDefaultModel(4)
-  const { data: ttsDefaultModel } = useDefaultModel(5)
+  const { data: textGenerationDefaultModel } = useDefaultModel(ModelTypeEnum.textGeneration)
+  const { data: embeddingsDefaultModel } = useDefaultModel(ModelTypeEnum.textEmbedding)
+  const { data: rerankDefaultModel } = useDefaultModel(ModelTypeEnum.rerank)
+  const { data: speech2textDefaultModel } = useDefaultModel(ModelTypeEnum.speech2text)
+  const { data: ttsDefaultModel } = useDefaultModel(ModelTypeEnum.tts)
   const { modelProviders: providers } = useProviderContext()
   const { setShowModelModal } = useModalContext()
   const defaultModelNotConfigured = !textGenerationDefaultModel && !embeddingsDefaultModel && !speech2textDefaultModel && !rerankDefaultModel && !ttsDefaultModel

+ 4 - 4
web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx

@@ -42,10 +42,10 @@ const SystemModel: FC<SystemModelSelectorProps> = ({
   const { notify } = useToastContext()
   const { textGenerationModelList } = useProviderContext()
   const updateModelList = useUpdateModelList()
-  const { data: embeddingModelList } = useModelList(2)
-  const { data: rerankModelList } = useModelList(3)
-  const { data: speech2textModelList } = useModelList(4)
-  const { data: ttsModelList } = useModelList(5)
+  const { data: embeddingModelList } = useModelList(ModelTypeEnum.textEmbedding)
+  const { data: rerankModelList } = useModelList(ModelTypeEnum.rerank)
+  const { data: speech2textModelList } = useModelList(ModelTypeEnum.speech2text)
+  const { data: ttsModelList } = useModelList(ModelTypeEnum.tts)
   const [changedModelTypes, setChangedModelTypes] = useState<ModelTypeEnum[]>([])
   const [currentTextGenerationDefaultModel, changeCurrentTextGenerationDefaultModel] = useSystemDefaultModelAndModelList(textGenerationDefaultModel, textGenerationModelList)
   const [currentEmbeddingsDefaultModel, changeCurrentEmbeddingsDefaultModel] = useSystemDefaultModelAndModelList(embeddingsDefaultModel, embeddingModelList)