index.tsx 14 KB


  1. import type { FC } from 'react'
  2. import React, { Fragment, useEffect, useState } from 'react'
  3. import useSWR from 'swr'
  4. import { Popover, Transition } from '@headlessui/react'
  5. import { useTranslation } from 'react-i18next'
  6. import _ from 'lodash-es'
  7. import cn from 'classnames'
  8. import ModelModal from '../model-modal'
  9. import cohereConfig from '../configs/cohere'
  10. import s from './style.module.css'
  11. import type { BackendModel, FormValue, ProviderEnum } from '@/app/components/header/account-setting/model-page/declarations'
  12. import { ModelType } from '@/app/components/header/account-setting/model-page/declarations'
  13. import { ChevronDown } from '@/app/components/base/icons/src/vender/line/arrows'
  14. import { Check, LinkExternal01, SearchLg } from '@/app/components/base/icons/src/vender/line/general'
  15. import { XCircle } from '@/app/components/base/icons/src/vender/solid/general'
  16. import { AlertCircle } from '@/app/components/base/icons/src/vender/line/alertsAndFeedback'
  17. import Tooltip from '@/app/components/base/tooltip'
  18. import ModelIcon from '@/app/components/app/configuration/config-model/model-icon'
  19. import ModelName from '@/app/components/app/configuration/config-model/model-name'
  20. import ProviderName from '@/app/components/app/configuration/config-model/provider-name'
  21. import { useProviderContext } from '@/context/provider-context'
  22. import ModelModeTypeLabel from '@/app/components/app/configuration/config-model/model-mode-type-label'
  23. import type { ModelModeType } from '@/types/app'
  24. import { CubeOutline } from '@/app/components/base/icons/src/vender/line/shapes'
  25. import { useModalContext } from '@/context/modal-context'
  26. import { useEventEmitterContextContext } from '@/context/event-emitter'
  27. import { fetchDefaultModal, setModelProvider } from '@/service/common'
  28. import { useToastContext } from '@/app/components/base/toast'
  /** A model reference: which provider it belongs to and its model name. */
  type SelectedModelRef = {
    providerName: ProviderEnum
    modelName: string
  }

  /** Props accepted by `ModelSelector`. */
  type Props = {
    /** The currently selected model, or `undefined` when nothing is selected. */
    value: SelectedModelRef | undefined
    /** Which model list to offer; not consulted when `supportAgentThought` is set. */
    modelType: ModelType
    /** Called with the model the user picks from the list. */
    onChange: (value: BackendModel) => void

    /** Show each model's mode label; also widens the dropdown panel. */
    isShowModelModeType?: boolean
    /** Show an "add more model" footer that opens the account-setting provider page. */
    isShowAddModel?: boolean
    /** Offer the agent-thought model list instead of the `modelType` list. */
    supportAgentThought?: boolean
    /** Extra class names applied to the dropdown panel. */
    popClassName?: string
    /** Render the trigger greyed out and never render the dropdown panel. */
    readonly?: boolean
    /** When falsy, the trigger's model icon is given an explicit `w-5 h-5` size. */
    triggerIconSmall?: boolean
    /** With no selection, the trigger opens the provider setup modal instead of the list. */
    whenEmptyGoToSetting?: boolean
    /** Called after the provider setup modal has been saved successfully. */
    onUpdate?: () => void
  }
  /** Non-selectable group header naming the provider whose models follow it. */
  type ProviderHeaderOption = {
    type: 'provider'
    value: ProviderEnum
  }

  /** Selectable row representing a single model of a provider. */
  type ModelEntryOption = {
    type: 'model'
    /** The model's `model_name`. */
    value: string
    providerName: ProviderEnum
    modelDisplayName: string
    model_mode: ModelModeType
  }

  /** One row of the dropdown panel, discriminated by its `type` tag. */
  type ModelOption = ProviderHeaderOption | ModelEntryOption
  /**
   * Dropdown for choosing a model from the models available in the current workspace.
   *
   * - The list comes from `useProviderContext`: `agentThoughtModelList` when
   *   `supportAgentThought` is set, otherwise the list matching `modelType`.
   * - The trigger shows the selected `value`. If that model is no longer in the list,
   *   the trigger is tinted red and a warning tooltip is shown.
   * - The panel groups models under their provider. It can be filtered by display
   *   name (case-insensitive), and picking a row calls `onChange` with that model.
   * - With `whenEmptyGoToSetting` and no selection, the trigger opens the Cohere
   *   provider setup modal instead. After a successful save, the workspace's default
   *   reranking model is fetched and passed to `onChange`.
   */
  const ModelSelector: FC<Props> = ({
    value,
    modelType,
    isShowModelModeType,
    isShowAddModel,
    supportAgentThought,
    onChange,
    popClassName,
    readonly,
    triggerIconSmall,
    whenEmptyGoToSetting,
    onUpdate,
  }) => {
    const { t } = useTranslation()
    const { setShowAccountSettingModal } = useModalContext()
    const {
      textGenerationModelList,
      embeddingsModelList,
      speech2textModelList,
      rerankModelList,
      agentThoughtModelList,
      updateModelList,
    } = useProviderContext()
    const { eventEmitter } = useEventEmitterContextContext()
    const { notify } = useToastContext()
    const [search, setSearch] = useState('')
    const [showRerankModal, setShowRerankModal] = useState(false)
    const [shouldFetchRerankDefaultModel, setShouldFetchRerankDefaultModel] = useState(false)
    // A null key keeps SWR idle until a provider has been saved from this selector.
    const { data: rerankDefaultModel } = useSWR(
      shouldFetchRerankDefaultModel ? '/workspaces/current/default-model?model_type=reranking' : null,
      fetchDefaultModal,
    )

    const modelList = supportAgentThought
      ? agentThoughtModelList
      : ({
        [ModelType.textGeneration]: textGenerationModelList,
        [ModelType.embeddings]: embeddingsModelList,
        [ModelType.speech2text]: speech2textModelList,
        [ModelType.reranking]: rerankModelList,
      })[modelType]

    // `value` only counts as a selection when both provider and model name are non-empty.
    const selected = (value && value.modelName && value.providerName) ? value : undefined
    const isSelected = (providerName: ProviderEnum, modelName: string) =>
      value?.providerName === providerName && value?.modelName === modelName
    const currModel = selected
      ? modelList.find(item => item.model_name === selected.modelName && item.model_provider.provider_name === selected.providerName)
      : undefined
    // A selection with no matching entry in the current list has been removed or is unavailable.
    const hasRemoved = !!selected && !currModel

    // Match each model against its OWN display name (falling back to its id). A lookup
    // keyed only by `model_name` would confuse same-named models from different providers.
    const keyword = search.toLowerCase()
    const filteredModelList = keyword
      ? modelList.filter(({ model_name, model_display_name }) =>
        (model_display_name || model_name).toLowerCase().includes(keyword))
      : modelList

    // Flatten into rows: each provider header followed by that provider's models,
    // in order of each provider's first appearance in the filtered list.
    const modelOptions: ModelOption[] = []
    _.uniq(filteredModelList.map(item => item.model_provider.provider_name)).forEach((providerName) => {
      modelOptions.push({ type: 'provider', value: providerName })
      filteredModelList
        .filter(m => m.model_provider.provider_name === providerName)
        .forEach(({ model_name, model_display_name, model_mode }) => {
          modelOptions.push({
            type: 'model',
            providerName,
            value: model_name,
            modelDisplayName: model_display_name,
            model_mode,
          })
        })
    })

    // In "go to setting" mode the chevron only appears once something is selected.
    const showChevron = whenEmptyGoToSetting ? !!selected : !readonly

    const handleOpenRerankModal = (e: React.MouseEvent<HTMLDivElement>) => {
      // Keep the enclosing Popover.Button from also handling this click.
      e.stopPropagation()
      setShowRerankModal(true)
    }

    /** Saves the provider config from the setup modal, then refreshes reranking models. */
    const handleRerankModalSave = async (originValue?: FormValue) => {
      if (!originValue)
        return
      try {
        eventEmitter?.emit('provider-save')
        const res = await setModelProvider({
          url: `/workspaces/current/model-providers/${cohereConfig.modal.key}`,
          body: {
            config: originValue,
          },
        })
        if (res.result === 'success') {
          notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') })
          updateModelList(ModelType.reranking)
          setShowRerankModal(false)
          setShouldFetchRerankDefaultModel(true)
          onUpdate?.()
        }
      }
      catch {
        // Deliberately not rethrown: the modal stays open so the user can retry.
        // NOTE(review): assumes the request layer already reports the failure — confirm.
      }
      finally {
        // Always clear the "saving" event, whether the request succeeded or failed.
        eventEmitter?.emit('')
      }
    }

    // Auto-select the default rerank model once it arrives after provider setup.
    // Keyed on the fetched model only, so a new `onChange` identity on each parent
    // render does not re-apply the selection.
    useEffect(() => {
      if (rerankDefaultModel && whenEmptyGoToSetting)
        onChange(rerankDefaultModel)
    }, [rerankDefaultModel])

    return (
      <div>
        <Popover className='relative'>
          <Popover.Button className={cn('flex items-center px-2.5 w-full h-9 rounded-lg', readonly ? '!cursor-auto bg-gray-100 opacity-50' : 'bg-gray-100', hasRemoved && '!bg-[#FEF3F2]')}>
            {({ open }) => (
              <>
                {selected
                  ? (
                    <>
                      <ModelIcon
                        className={cn('mr-1.5', !triggerIconSmall && 'w-5 h-5')}
                        modelId={selected.modelName}
                        providerName={selected.providerName}
                      />
                      <div className='mr-1.5 grow flex items-center text-left text-sm text-gray-900 truncate'>
                        <ModelName modelId={selected.modelName} modelDisplayName={currModel?.model_display_name || selected.modelName} />
                        {isShowModelModeType && (
                          <ModelModeTypeLabel className='ml-2' type={currModel?.model_mode as ModelModeType} />
                        )}
                      </div>
                    </>
                  )
                  : whenEmptyGoToSetting
                    ? (
                      <div className='grow flex items-center h-9 justify-between' onClick={handleOpenRerankModal}>
                        <div className='flex items-center text-[13px] font-medium text-primary-500'>
                          <CubeOutline className='mr-1.5 w-4 h-4' />
                          {t('common.modelProvider.selector.rerankTip')}
                        </div>
                        <LinkExternal01 className='w-3 h-3 text-gray-500' />
                      </div>
                    )
                    : (
                      <div className='grow text-left text-sm text-gray-800 opacity-60'>{t('common.modelProvider.selectModel')}</div>
                    )}
                {hasRemoved && (
                  <Tooltip
                    selector='model-selector-remove-tip'
                    htmlContent={
                      <div className='w-[261px] text-gray-500'>{t('common.modelProvider.selector.tip')}</div>
                    }
                  >
                    <AlertCircle className='mr-1 w-4 h-4 text-[#F04438]' />
                  </Tooltip>
                )}
                {showChevron && (
                  <ChevronDown className={cn('w-4 h-4 text-gray-700', open ? 'opacity-100' : 'opacity-60')} />
                )}
              </>
            )}
          </Popover.Button>
          {!readonly && (
            <Transition
              as={Fragment}
              leave='transition ease-in duration-100'
              leaveFrom='opacity-100'
              leaveTo='opacity-0'
            >
              <Popover.Panel className={cn(popClassName, isShowModelModeType ? 'max-w-[312px]' : 'max-w-[260px]', 'absolute top-10 p-1 min-w-[232px] max-h-[366px] bg-white border-[0.5px] border-gray-200 rounded-lg shadow-lg overflow-auto z-10')}>
                <div className='px-2 pt-2 pb-1'>
                  <div className='flex items-center px-2 h-8 bg-gray-100 rounded-lg'>
                    <div className='mr-1.5 p-[1px]'><SearchLg className='w-[14px] h-[14px] text-gray-400' /></div>
                    <div className='grow px-0.5'>
                      <input
                        value={search}
                        onChange={e => setSearch(e.target.value)}
                        className='block w-full h-8 bg-transparent text-[13px] text-gray-700 outline-none appearance-none border-none'
                        placeholder={t('common.modelProvider.searchModel') || ''}
                      />
                    </div>
                    {search && (
                      <div className='ml-1 p-0.5 cursor-pointer' onClick={() => setSearch('')}>
                        <XCircle className='w-3 h-3 text-gray-400' />
                      </div>
                    )}
                  </div>
                </div>
                {modelOptions.map((option) => {
                  if (option.type === 'provider') {
                    return (
                      <div
                        className='px-3 pt-2 pb-1 text-xs font-medium text-gray-500'
                        key={`${option.type}-${option.value}`}
                      >
                        <ProviderName provideName={option.value} />
                      </div>
                    )
                  }
                  const checked = isSelected(option.providerName, option.value)
                  return (
                    <Popover.Button
                      key={`${option.providerName}-${option.value}`}
                      className={cn(
                        s.optionItem,
                        'flex items-center px-3 w-full h-8 rounded-lg hover:bg-gray-50',
                        readonly ? 'cursor-auto' : 'cursor-pointer',
                        checked && 'bg-gray-50',
                      )}
                      onClick={() => {
                        const selectedModel = modelList.find(item =>
                          item.model_name === option.value && item.model_provider.provider_name === option.providerName)
                        // Options are derived from `modelList`, so this should always match.
                        if (selectedModel)
                          onChange(selectedModel)
                      }}
                    >
                      <ModelIcon
                        className='mr-2 shrink-0'
                        modelId={option.value}
                        providerName={option.providerName}
                      />
                      <div className='mr-2 grow flex items-center text-left text-sm text-gray-900 truncate'>
                        <ModelName modelId={option.value} modelDisplayName={option.modelDisplayName} />
                        {isShowModelModeType && (
                          <ModelModeTypeLabel className={`${s.modelModeLabel} ml-2`} type={option.model_mode} />
                        )}
                      </div>
                      {checked && <Check className='shrink-0 w-4 h-4 text-primary-600' />}
                    </Popover.Button>
                  )
                })}
                {!!search && modelList.length !== 0 && filteredModelList.length === 0 && (
                  <div className='px-3 pt-1.5 h-[30px] text-center text-xs text-gray-500'>{t('common.modelProvider.noModelFound', { model: search })}</div>
                )}
                {isShowAddModel && (
                  <div
                    className='border-t flex items-center h-9 pl-3 text-xs text-[#155EEF] cursor-pointer'
                    style={{
                      borderColor: 'rgba(0, 0, 0, 0.05)',
                    }}
                    onClick={() => setShowAccountSettingModal({ payload: 'provider' })}
                  >
                    <CubeOutline className='w-4 h-4 mr-2' />
                    <div>{t('common.model.addMoreModel')}</div>
                  </div>
                )}
              </Popover.Panel>
            </Transition>
          )}
        </Popover>
        <ModelModal
          isShow={showRerankModal}
          modelModal={cohereConfig.modal}
          onCancel={() => setShowRerankModal(false)}
          onSave={handleRerankModalSave}
          mode={'add'}
        />
      </div>
    )
  }

  export default ModelSelector