import type {
  FC,
  ReactNode,
} from 'react'
import { useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import type {
  DefaultModel,
  FormValue,
} from '@/app/components/header/account-setting/model-provider-page/declarations'
import { ModelStatusEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
import {
  useModelList,
} from '@/app/components/header/account-setting/model-provider-page/hooks'
import AgentModelTrigger from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal/agent-model-trigger'
import Trigger from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal/trigger'
import type { TriggerProps } from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal/trigger'
import {
  PortalToFollowElem,
  PortalToFollowElemContent,
  PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem'
import LLMParamsPanel from './llm-params-panel'
import TTSParamsPanel from './tts-params-panel'
import { useProviderContext } from '@/context/provider-context'
import cn from '@/utils/classnames'

/**
 * Props for the `ModelParameterModal` component.
 */
export type ModelParameterModalProps = {
  popupClassName?: string
  portalToFollowElemContentClassName?: string
  isAdvancedMode: boolean
  /**
   * Currently selected model. The handlers in this component read and write
   * `provider`, `model`, `model_type`, `mode`, `completion_params`,
   * `language` and `voice` on it.
   */
  value: any
  /** Receives a complete replacement model object on every selection or parameter change. */
  setModel: (model: any) => void
  /** Custom trigger renderer; when given it replaces the built-in `Trigger` / `AgentModelTrigger`. */
  renderTrigger?: (v: TriggerProps) => ReactNode
  /** When true, clicking the trigger does not toggle the popup. */
  readonly?: boolean
  isInWorkflow?: boolean
  /** Without `renderTrigger`, selects `AgentModelTrigger` instead of `Trigger`. */
  isAgentStrategy?: boolean
  /**
   * `&`-separated list of `ModelTypeEnum` values, or `all` for every model type.
   * Entries that are not built-in model types are collected into `scopeFeatures`.
   * Defaults to `ModelTypeEnum.textGeneration`.
   */
  scope?: string
}

/** Model types that each have their own model list; any other `scope` entry is treated as a feature. */
const BUILT_IN_MODEL_TYPES: ModelTypeEnum[] = [
  ModelTypeEnum.textGeneration,
  ModelTypeEnum.textEmbedding,
  ModelTypeEnum.rerank,
  ModelTypeEnum.moderation,
  ModelTypeEnum.speech2text,
  ModelTypeEnum.tts,
]

/**
 * Model picker with an attached parameter editor.
 *
 * Shows a trigger (custom via `renderTrigger`, otherwise `AgentModelTrigger` /
 * `Trigger`) that toggles a popup for choosing a model from the lists allowed
 * by `scope`. Text-generation models can have their completion params edited
 * and TTS models their language/voice; every change is reported through
 * `setModel` as a new object built from `value`.
 */
const ModelParameterModal: FC<ModelParameterModalProps> = ({
  popupClassName,
  portalToFollowElemContentClassName,
  isAdvancedMode,
  value,
  setModel,
  renderTrigger,
  readonly,
  isInWorkflow,
  isAgentStrategy,
  scope = ModelTypeEnum.textGeneration,
}) => {
  const { t } = useTranslation()
  const { isAPIKeySet } = useProviderContext()
  const [open, setOpen] = useState(false)

  // Memoised on `scope`: a bare `split` returns a new array every render,
  // which would invalidate every memo below that depends on it.
  const scopeArray = useMemo(() => scope.split('&'), [scope])

  // Scope entries that are not model types (none when the scope is `all`).
  const scopeFeatures = useMemo(() => {
    if (scopeArray.includes('all'))
      return []
    return scopeArray.filter(item => !BUILT_IN_MODEL_TYPES.includes(item as ModelTypeEnum))
  }, [scopeArray])

  const { data: textGenerationList } = useModelList(ModelTypeEnum.textGeneration)
  const { data: textEmbeddingList } = useModelList(ModelTypeEnum.textEmbedding)
  const { data: rerankList } = useModelList(ModelTypeEnum.rerank)
  const { data: moderationList } = useModelList(ModelTypeEnum.moderation)
  const { data: sttList } = useModelList(ModelTypeEnum.speech2text)
  const { data: ttsList } = useModelList(ModelTypeEnum.tts)

  // Provider/model list offered for selection. `all` concatenates every list;
  // otherwise only the first scope type matched in the fixed order below is used.
  // NOTE(review): a multi-type scope (e.g. two model types joined by `&`) is not
  // merged — only its first match by this order is returned; confirm intended.
  const scopedModelList = useMemo(() => {
    const resultList: any[] = []
    if (scopeArray.includes('all')) {
      return [
        ...textGenerationList,
        ...textEmbeddingList,
        ...rerankList,
        ...sttList,
        ...ttsList,
        ...moderationList,
      ]
    }
    if (scopeArray.includes(ModelTypeEnum.textGeneration))
      return textGenerationList
    if (scopeArray.includes(ModelTypeEnum.textEmbedding))
      return textEmbeddingList
    if (scopeArray.includes(ModelTypeEnum.rerank))
      return rerankList
    if (scopeArray.includes(ModelTypeEnum.moderation))
      return moderationList
    if (scopeArray.includes(ModelTypeEnum.speech2text))
      return sttList
    if (scopeArray.includes(ModelTypeEnum.tts))
      return ttsList
    return resultList
  }, [scopeArray, textGenerationList, textEmbeddingList, rerankList, sttList, ttsList, moderationList])

  // Resolve `value` against the scoped list; either may be undefined when the
  // stored selection no longer exists.
  const { currentProvider, currentModel } = useMemo(() => {
    const currentProvider = scopedModelList.find(item => item.provider === value?.provider)
    const currentModel = currentProvider?.models.find((model: { model: string }) => model.model === value?.model)
    return {
      currentProvider,
      currentModel,
    }
  }, [scopedModelList, value?.provider, value?.model])

  // The stored provider or model is not in the scoped list.
  const hasDeprecated = useMemo(() => {
    return !currentProvider || !currentModel
  }, [currentModel, currentProvider])

  const modelDisabled = useMemo(() => {
    return currentModel?.status !== ModelStatusEnum.active
  }, [currentModel?.status])

  // Unusable when no API key is set, the selection is missing, or the model is not active.
  const disabled = useMemo(() => {
    return !isAPIKeySet || hasDeprecated || modelDisabled
  }, [hasDeprecated, isAPIKeySet, modelDisabled])

  // New selection: emit provider/model/model_type; text-generation models also
  // get their `mode` and a reset (empty) `completion_params`.
  const handleChangeModel = ({ provider, model }: DefaultModel) => {
    const targetProvider = scopedModelList.find(modelItem => modelItem.provider === provider)
    const targetModelItem = targetProvider?.models.find((modelItem: { model: string }) => modelItem.model === model)
    const model_type = targetModelItem?.model_type as string
    setModel({
      provider,
      model,
      model_type,
      ...(model_type === ModelTypeEnum.textGeneration
        ? {
          mode: targetModelItem?.model_properties.mode as string,
          completion_params: {},
        }
        : {}),
    })
  }

  // Replace only `completion_params`; every other field of `value` is kept.
  const handleLLMParamsChange = (newParams: FormValue) => {
    setModel({
      ...value,
      completion_params: newParams,
    })
  }

  const handleTTSParamsChange = (language: string, voice: string) => {
    setModel({
      ...value,
      language,
      voice,
    })
  }

  return (
    {
      if (readonly)
        return
      setOpen(v => !v)
    }} className='block' > { renderTrigger ? renderTrigger({ open, disabled, modelDisabled, hasDeprecated, currentProvider, currentModel, providerName: value?.provider, modelId: value?.model, }) : (isAgentStrategy ? : ) }
    {t('common.modelProvider.model').toLocaleUpperCase()}
    {(currentModel?.model_type === ModelTypeEnum.textGeneration || currentModel?.model_type === ModelTypeEnum.tts) && (
    )} {currentModel?.model_type === ModelTypeEnum.textGeneration && ( )} {currentModel?.model_type === ModelTypeEnum.tts && ( )}
  )
}

export default ModelParameterModal