index.tsx 8.1 KB


  1. import { useState } from 'react'
  2. import { useTranslation } from 'react-i18next'
  3. import ModelSelector from '../model-selector'
  4. import type {
  5. BackendModel, ProviderEnum,
  6. } from '../declarations'
  7. import Tooltip from '@/app/components/base/tooltip'
  8. import { HelpCircle, Settings01 } from '@/app/components/base/icons/src/vender/line/general'
  9. import {
  10. PortalToFollowElem,
  11. PortalToFollowElemContent,
  12. PortalToFollowElemTrigger,
  13. } from '@/app/components/base/portal-to-follow-elem'
  14. import { useProviderContext } from '@/context/provider-context'
  15. import { updateDefaultModel } from '@/service/common'
  16. import { ModelType } from '@/app/components/header/account-setting/model-page/declarations'
  17. import { useToastContext } from '@/app/components/base/toast'
  18. import Button from '@/app/components/base/button'
  /**
   * Popup for choosing the workspace-wide default model of each model type
   * (text generation, embeddings, reranking, speech-to-text).
   *
   * Selections are held in local state and only sent to
   * `/workspaces/current/default-model` when the user presses Save. After a
   * successful save the matching provider-context entries are re-validated via
   * their `mutate*` callbacks.
   */
  const SystemModel = () => {
    const { t } = useTranslation()
    const {
      textGenerationDefaultModel,
      mutateTextGenerationDefaultModel,
      embeddingsDefaultModel,
      mutateEmbeddingsDefaultModel,
      speech2textDefaultModel,
      mutateSpeech2textDefaultModel,
      rerankDefaultModel,
      mutateRerankDefaultModel,
    } = useProviderContext()
    const { notify } = useToastContext()
    const [open, setOpen] = useState(false)
    // Pending (unsaved) selection per model type, seeded once from the current
    // defaults exposed by the provider context. A type without a default stays
    // undefined (or whatever falsy value the context supplied).
    const [selectedModel, setSelectedModel] = useState<Record<ModelType, { providerName: ProviderEnum; modelName: string } | undefined>>({
      [ModelType.textGeneration]: textGenerationDefaultModel && { providerName: textGenerationDefaultModel.model_provider.provider_name, modelName: textGenerationDefaultModel.model_name },
      [ModelType.embeddings]: embeddingsDefaultModel && { providerName: embeddingsDefaultModel.model_provider.provider_name, modelName: embeddingsDefaultModel.model_name },
      [ModelType.speech2text]: speech2textDefaultModel && { providerName: speech2textDefaultModel.model_provider.provider_name, modelName: speech2textDefaultModel.model_name },
      [ModelType.reranking]: rerankDefaultModel && { providerName: rerankDefaultModel.model_provider.provider_name, modelName: rerankDefaultModel.model_name },
    })

    /** Calls the provider-context mutate callback for every given model type. */
    const mutateDefaultModel = (types: ModelType[]) => {
      types.forEach((type) => {
        switch (type) {
          case ModelType.textGeneration:
            mutateTextGenerationDefaultModel()
            break
          case ModelType.embeddings:
            mutateEmbeddingsDefaultModel()
            break
          case ModelType.speech2text:
            mutateSpeech2textDefaultModel()
            break
          case ModelType.reranking:
            mutateRerankDefaultModel()
            break
        }
      })
    }

    /**
     * Records the model picked for `type` in local state (nothing is persisted
     * until Save). Uses a functional update so that consecutive changes never
     * overwrite each other with a stale snapshot of the state.
     */
    const handleChangeDefaultModel = (type: ModelType, v: BackendModel) => {
      setSelectedModel(prev => ({
        ...prev,
        [type]: {
          providerName: v.model_provider.provider_name,
          modelName: v.model_name,
        },
      }))
    }

    /**
     * Sends the selection of every model type to the backend; on success shows
     * a toast and triggers the mutate callbacks of all submitted types.
     *
     * NOTE(review): types without a selection are still submitted with
     * undefined provider/model names, and a rejected request is not caught
     * here — confirm both are intended by the backend/service layer.
     */
    const handleSave = async () => {
      const modelTypes = Object.keys(selectedModel) as ModelType[]
      const res = await updateDefaultModel({
        url: '/workspaces/current/default-model',
        body: {
          model_settings: modelTypes.map(modelType => ({
            model_type: modelType,
            provider_name: selectedModel[modelType]?.providerName,
            model_name: selectedModel[modelType]?.modelName,
          })),
        },
      })
      if (res.result === 'success') {
        notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') })
        mutateDefaultModel(modelTypes)
      }
    }

    /**
     * Renders one settings row: a label with a help tooltip, followed by the
     * model selector bound to the pending selection of `modelType`.
     */
    const renderModelField = (modelType: ModelType, label: string, tip: string, tooltipSelector: string) => (
      <div className='px-6 py-1'>
        <div className='flex items-center h-8 text-[13px] font-medium text-gray-900'>
          {label}
          <Tooltip
            selector={tooltipSelector}
            htmlContent={
              <div className='w-[261px] text-gray-500'>{tip}</div>
            }
          >
            <HelpCircle className='ml-0.5 w-[14px] h-[14px] text-gray-400' />
          </Tooltip>
        </div>
        <div>
          <ModelSelector
            value={selectedModel[modelType]}
            modelType={modelType}
            onChange={v => handleChangeDefaultModel(modelType, v)}
          />
        </div>
      </div>
    )

    return (
      <PortalToFollowElem
        open={open}
        onOpenChange={setOpen}
        placement='bottom-end'
        offset={{
          mainAxis: 4,
          crossAxis: 8,
        }}
      >
        <PortalToFollowElemTrigger onClick={() => setOpen(v => !v)}>
          {/* Ternary instead of `&&` so a closed popup does not emit a literal "false" class. */}
          <div className={`
            flex items-center px-2 h-6 text-xs text-gray-700 cursor-pointer rounded-md border-[0.5px] border-gray-200 shadow-xs
            hover:bg-gray-100 hover:shadow-none
            ${open ? 'bg-gray-100 shadow-none' : ''}
          `}>
            <Settings01 className='mr-1 w-3 h-3 text-gray-500' />
            {t('common.modelProvider.systemModelSettings')}
          </div>
        </PortalToFollowElemTrigger>
        <PortalToFollowElemContent className='z-50'>
          <div className='pt-4 w-[360px] rounded-xl border-[0.5px] border-black/5 bg-white shadow-xl'>
            {renderModelField(
              ModelType.textGeneration,
              t('common.modelProvider.systemReasoningModel.key'),
              t('common.modelProvider.systemReasoningModel.tip'),
              'model-page-system-reasoning-model-tip',
            )}
            {renderModelField(
              ModelType.embeddings,
              t('common.modelProvider.embeddingModel.key'),
              t('common.modelProvider.embeddingModel.tip'),
              'model-page-system-embedding-model-tip',
            )}
            {renderModelField(
              ModelType.reranking,
              t('common.modelProvider.rerankModel.key'),
              t('common.modelProvider.rerankModel.tip'),
              'model-page-system-rerankModel-model-tip',
            )}
            {renderModelField(
              ModelType.speech2text,
              t('common.modelProvider.speechToTextModel.key'),
              t('common.modelProvider.speechToTextModel.tip'),
              'model-page-system-speechToText-model-tip',
            )}
            <div className='flex items-center justify-end px-6 py-4'>
              <Button
                className='mr-2 !h-8 !text-[13px]'
                onClick={() => setOpen(false)}
              >
                {t('common.operation.cancel')}
              </Button>
              <Button
                type='primary'
                className='!h-8 !text-[13px]'
                onClick={handleSave}
              >
                {t('common.operation.save')}
              </Button>
            </div>
          </div>
        </PortalToFollowElemContent>
      </PortalToFollowElem>
    )
  }
  201. export default SystemModel