index.tsx 8.3 KB


  1. import type { FC } from 'react'
  2. import { useState } from 'react'
  3. import { useTranslation } from 'react-i18next'
  4. import ModelSelector from '../model-selector'
  5. import type {
  6. BackendModel, ProviderEnum,
  7. } from '../declarations'
  8. import Tooltip from '@/app/components/base/tooltip'
  9. import { HelpCircle, Settings01 } from '@/app/components/base/icons/src/vender/line/general'
  10. import {
  11. PortalToFollowElem,
  12. PortalToFollowElemContent,
  13. PortalToFollowElemTrigger,
  14. } from '@/app/components/base/portal-to-follow-elem'
  15. import { useProviderContext } from '@/context/provider-context'
  16. import { updateDefaultModel } from '@/service/common'
  17. import { ModelType } from '@/app/components/header/account-setting/model-page/declarations'
  18. import { useToastContext } from '@/app/components/base/toast'
  19. import Button from '@/app/components/base/button'
  /** Props accepted by the `SystemModel` popover component. */
  interface SystemModelProps {
    /** Callback handed to the rerank-model `ModelSelector` as its `onUpdate` prop. */
    onUpdate: () => void
  }
  /** A provider/model pair in the shape `ModelSelector` takes as its `value`. */
  type ModelSelection = { providerName: ProviderEnum; modelName: string }

  /** Draft default model per model type; `undefined` means "not set". */
  type ModelSelections = Record<ModelType, ModelSelection | undefined>

  /**
   * Maps a backend model description (`model_provider.provider_name` and
   * `model_name`) onto a `ModelSelection`. A missing model maps to `undefined`.
   */
  const toModelSelection = (
    model?: { model_name: string; model_provider: { provider_name: ProviderEnum } } | null,
  ): ModelSelection | undefined =>
    model
      ? { providerName: model.model_provider.provider_name, modelName: model.model_name }
      : undefined

  /** Static description of one model-type row in the popover. */
  type ModelSection = {
    type: ModelType
    /** i18n key of the row label. */
    labelKey: string
    /** i18n key of the help-tooltip text. */
    tipKey: string
    /** `selector` passed to the row's `Tooltip`. */
    tooltipSelector: string
    /** When true, the row's `ModelSelector` gets `whenEmptyGoToSetting` and the component's `onUpdate`. */
    whenEmptyGoToSetting?: boolean
  }

  // Rows in display order; only the rerank row links to settings when empty.
  const MODEL_SECTIONS: ModelSection[] = [
    {
      type: ModelType.textGeneration,
      labelKey: 'common.modelProvider.systemReasoningModel.key',
      tipKey: 'common.modelProvider.systemReasoningModel.tip',
      tooltipSelector: 'model-page-system-reasoning-model-tip',
    },
    {
      type: ModelType.embeddings,
      labelKey: 'common.modelProvider.embeddingModel.key',
      tipKey: 'common.modelProvider.embeddingModel.tip',
      tooltipSelector: 'model-page-system-embedding-model-tip',
    },
    {
      type: ModelType.reranking,
      labelKey: 'common.modelProvider.rerankModel.key',
      tipKey: 'common.modelProvider.rerankModel.tip',
      tooltipSelector: 'model-page-system-rerankModel-model-tip',
      whenEmptyGoToSetting: true,
    },
    {
      type: ModelType.speech2text,
      labelKey: 'common.modelProvider.speechToTextModel.key',
      tipKey: 'common.modelProvider.speechToTextModel.tip',
      tooltipSelector: 'model-page-system-speechToText-model-tip',
    },
  ]

  /**
   * "System model settings" popover. It lets the user pick the workspace default
   * model for text generation, embeddings, rerank and speech-to-text, and saves
   * all of them in a single `updateDefaultModel` request.
   *
   * Choices are edited in a local draft. The draft is re-seeded from the
   * provider context every time the popover opens, so defaults that arrive after
   * mount are shown and an abandoned edit does not carry over to the next opening.
   */
  const SystemModel: FC<SystemModelProps> = ({
    onUpdate,
  }) => {
    const { t } = useTranslation()
    const {
      textGenerationDefaultModel,
      mutateTextGenerationDefaultModel,
      embeddingsDefaultModel,
      mutateEmbeddingsDefaultModel,
      speech2textDefaultModel,
      mutateSpeech2textDefaultModel,
      rerankDefaultModel,
      mutateRerankDefaultModel,
    } = useProviderContext()
    const { notify } = useToastContext()
    const [open, setOpen] = useState(false)

    // The defaults currently held by the provider context, in draft shape.
    const getDefaultSelections = (): ModelSelections => ({
      [ModelType.textGeneration]: toModelSelection(textGenerationDefaultModel),
      [ModelType.embeddings]: toModelSelection(embeddingsDefaultModel),
      [ModelType.speech2text]: toModelSelection(speech2textDefaultModel),
      [ModelType.reranking]: toModelSelection(rerankDefaultModel),
    })
    const [selectedModel, setSelectedModel] = useState<ModelSelections>(getDefaultSelections)

    // Asks the provider context to refetch the default model of each given type.
    const mutateDefaultModel = (types: ModelType[]) => {
      const mutators: Record<ModelType, () => unknown> = {
        [ModelType.textGeneration]: mutateTextGenerationDefaultModel,
        [ModelType.embeddings]: mutateEmbeddingsDefaultModel,
        [ModelType.speech2text]: mutateSpeech2textDefaultModel,
        [ModelType.reranking]: mutateRerankDefaultModel,
      }
      for (const type of types)
        mutators[type]()
    }

    // Opening the popover re-seeds the draft from the context: the initial
    // useState value is computed only once and may predate the defaults loading.
    const handleTriggerClick = () => {
      if (!open)
        setSelectedModel(getDefaultSelections())
      setOpen(!open)
    }

    // Records a pick in the draft only; nothing is persisted until Save.
    const handleChangeDefaultModel = (type: ModelType, model: BackendModel) => {
      setSelectedModel(prev => ({ ...prev, [type]: toModelSelection(model) }))
    }

    // Persists every model type in the draft, then refreshes the context and closes.
    const handleSave = async () => {
      const modelTypes = Object.keys(selectedModel) as ModelType[]
      const res = await updateDefaultModel({
        url: '/workspaces/current/default-model',
        body: {
          model_settings: modelTypes.map(modelType => ({
            model_type: modelType,
            provider_name: selectedModel[modelType]?.providerName,
            model_name: selectedModel[modelType]?.modelName,
          })),
        },
      })
      if (res.result === 'success') {
        notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') })
        mutateDefaultModel(modelTypes)
        setOpen(false)
      }
    }

    return (
      <PortalToFollowElem
        open={open}
        onOpenChange={setOpen}
        placement='bottom-end'
        offset={{
          mainAxis: 4,
          crossAxis: 8,
        }}
      >
        <PortalToFollowElemTrigger onClick={handleTriggerClick}>
          <div className={`
            flex items-center px-2 h-6 text-xs text-gray-700 cursor-pointer bg-white rounded-md border-[0.5px] border-gray-200 shadow-xs
            hover:bg-gray-100 hover:shadow-none
            ${open ? 'bg-gray-100 shadow-none' : ''}
          `}>
            <Settings01 className='mr-1 w-3 h-3 text-gray-500' />
            {t('common.modelProvider.systemModelSettings')}
          </div>
        </PortalToFollowElemTrigger>
        <PortalToFollowElemContent className='z-50'>
          <div className='pt-4 w-[360px] rounded-xl border-[0.5px] border-black/5 bg-white shadow-xl'>
            {MODEL_SECTIONS.map(section => (
              <div key={section.type} className='px-6 py-1'>
                <div className='flex items-center h-8 text-[13px] font-medium text-gray-900'>
                  {t(section.labelKey)}
                  <Tooltip
                    selector={section.tooltipSelector}
                    htmlContent={
                      <div className='w-[261px] text-gray-500'>{t(section.tipKey)}</div>
                    }
                  >
                    <HelpCircle className='ml-0.5 w-[14px] h-[14px] text-gray-400' />
                  </Tooltip>
                </div>
                <div>
                  <ModelSelector
                    value={selectedModel[section.type]}
                    modelType={section.type}
                    onChange={v => handleChangeDefaultModel(section.type, v)}
                    whenEmptyGoToSetting={section.whenEmptyGoToSetting}
                    onUpdate={section.whenEmptyGoToSetting ? onUpdate : undefined}
                  />
                </div>
              </div>
            ))}
            <div className='flex items-center justify-end px-6 py-4'>
              <Button
                className='mr-2 !h-8 !text-[13px]'
                onClick={() => setOpen(false)}
              >
                {t('common.operation.cancel')}
              </Button>
              <Button
                type='primary'
                className='!h-8 !text-[13px]'
                onClick={handleSave}
              >
                {t('common.operation.save')}
              </Button>
            </div>
          </div>
        </PortalToFollowElemContent>
      </PortalToFollowElem>
    )
  }
  export default SystemModel