index.tsx 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. 'use client'
  2. import type { FC } from 'react'
  3. import { memo, useState } from 'react'
  4. import { useTranslation } from 'react-i18next'
  5. import { useContext } from 'use-context-selector'
  6. import cn from 'classnames'
  7. import { Settings04 } from '@/app/components/base/icons/src/vender/line/general'
  8. import ConfigContext from '@/context/debug-configuration'
  9. import TopKItem from '@/app/components/base/param-item/top-k-item'
  10. import ScoreThresholdItem from '@/app/components/base/param-item/score-threshold-item'
  11. import Modal from '@/app/components/base/modal'
  12. import Button from '@/app/components/base/button'
  13. import RadioCard from '@/app/components/base/radio-card/simple'
  14. import { RETRIEVE_TYPE } from '@/types/app'
  15. import Toast from '@/app/components/base/toast'
  16. import { DATASET_DEFAULT } from '@/config'
  17. import {
  18. MultiPathRetrieval,
  19. NTo1Retrieval,
  20. } from '@/app/components/base/icons/src/public/common'
  21. import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
  22. import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
  23. import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
  /**
   * Retrieval-parameter settings for the dataset section of the debug
   * configuration.
   *
   * Renders a "params" trigger that opens a modal. Opening the modal copies
   * `datasetConfigs` from ConfigContext into a local draft, and every control in
   * the modal edits only that draft. "Save" validates the draft, fills in the
   * default rerank model for multi-way retrieval when none was picked, and writes
   * the result back with `setDatasetConfigs`. "Cancel" or closing the modal
   * simply hides it; the draft is re-seeded on the next open.
   */
  const ParamsConfig: FC = () => {
    const { t } = useTranslation()
    const [open, setOpen] = useState(false)
    const {
      datasetConfigs,
      setDatasetConfigs,
    } = useContext(ConfigContext)
    // Draft edited inside the modal; re-seeded from datasetConfigs on every open.
    const [tempDataSetConfigs, setTempDataSetConfigs] = useState(datasetConfigs)

    const type = tempDataSetConfigs.retrieval_model
    const setType = (value: RETRIEVE_TYPE) => {
      setTempDataSetConfigs(prev => ({
        ...prev,
        retrieval_model: value,
      }))
    }

    const {
      modelList: rerankModelList,
      defaultModel: rerankDefaultModel,
      // NOTE(review): presumably truthy only when the default rerank model is
      // present in the available model list — verify against the hook.
      currentModel: isRerankDefaultModelValid,
    } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)

    // A rerank selection only counts as explicit when it names a model. An object
    // with an empty model name is treated like a missing one, so display,
    // validation and save all fall back to the default under the same rule.
    const hasExplicitRerankModel = !!tempDataSetConfigs.reranking_model?.reranking_model_name

    // Model shown in the selector: the explicit draft choice, else the default
    // rerank model, else undefined (nothing selected).
    const rerankModel = (() => {
      const selected = tempDataSetConfigs.reranking_model
      if (selected?.reranking_model_name) {
        return {
          provider_name: selected.reranking_provider_name,
          model_name: selected.reranking_model_name,
        }
      }
      if (rerankDefaultModel) {
        return {
          provider_name: rerankDefaultModel.provider.provider,
          model_name: rerankDefaultModel.model,
        }
      }
      return undefined
    })()

    /** Writes a numeric parameter (`top_k` or `score_threshold`) into the draft; other keys are ignored. */
    const handleParamChange = (key: string, value: number) => {
      if (key === 'top_k') {
        setTempDataSetConfigs(prev => ({
          ...prev,
          top_k: value,
        }))
      }
      else if (key === 'score_threshold') {
        setTempDataSetConfigs(prev => ({
          ...prev,
          score_threshold: value,
        }))
      }
    }

    /** Toggles the score threshold on/off; the top-k item is always enabled, so its key is ignored. */
    const handleSwitch = (key: string, enable: boolean) => {
      if (key === 'top_k')
        return
      setTempDataSetConfigs(prev => ({
        ...prev,
        score_threshold_enabled: enable,
      }))
    }

    /**
     * Checks the draft before saving and shows an error toast on failure.
     * Multi-way retrieval needs a rerank model: either an explicit choice, or a
     * default that exists and is usable so handleSave can fill it in.
     */
    const isValid = () => {
      let errMsg = ''
      if (tempDataSetConfigs.retrieval_model === RETRIEVE_TYPE.multiWay) {
        if (!hasExplicitRerankModel && (!rerankDefaultModel || !isRerankDefaultModelValid))
          errMsg = t('appDebug.datasetConfig.rerankModelRequired')
      }
      if (errMsg) {
        Toast.notify({
          type: 'error',
          message: errMsg,
        })
      }
      return !errMsg
    }

    /** Validates the draft, fills in the default rerank model when needed, persists it and closes the modal. */
    const handleSave = () => {
      if (!isValid())
        return
      const config = { ...tempDataSetConfigs }
      // isValid() has already ensured a default exists in this case; the
      // rerankDefaultModel check just lets the compiler see it without a cast.
      if (config.retrieval_model === RETRIEVE_TYPE.multiWay && !hasExplicitRerankModel && rerankDefaultModel) {
        config.reranking_model = {
          reranking_provider_name: rerankDefaultModel.provider.provider,
          reranking_model_name: rerankDefaultModel.model,
        }
      }
      setDatasetConfigs(config)
      setOpen(false)
    }

    /** Seeds the draft from the saved configuration and opens the modal. */
    const handleOpen = () => {
      setTempDataSetConfigs({
        ...datasetConfigs,
        // NOTE(review): presumably 0 is not a usable top_k, so any falsy value
        // falls back to the default.
        top_k: datasetConfigs.top_k || DATASET_DEFAULT.top_k,
        // `??` keeps an explicitly saved threshold of 0 instead of replacing it.
        score_threshold: datasetConfigs.score_threshold ?? DATASET_DEFAULT.score_threshold,
      })
      setOpen(true)
    }

    return (
      <div>
        <div
          className={cn('flex items-center rounded-md h-7 px-3 space-x-1 text-gray-700 cursor-pointer hover:bg-gray-200', open && 'bg-gray-200')}
          onClick={handleOpen}
        >
          <Settings04 className="w-[14px] h-[14px]" />
          <div className='text-xs font-medium'>
            {t('appDebug.datasetConfig.params')}
          </div>
        </div>
        {
          open && (
            <Modal
              isShow={open}
              onClose={() => {
                setOpen(false)
              }}
              className='sm:min-w-[528px]'
              wrapperClassName='z-50'
              title={t('appDebug.datasetConfig.settingTitle')}
            >
              <div className='mt-2 space-y-3'>
                <RadioCard
                  icon={<NTo1Retrieval className='shrink-0 mr-3 w-9 h-9 rounded-lg' />}
                  title={t('appDebug.datasetConfig.retrieveOneWay.title')}
                  description={t('appDebug.datasetConfig.retrieveOneWay.description')}
                  isChosen={type === RETRIEVE_TYPE.oneWay}
                  onChosen={() => { setType(RETRIEVE_TYPE.oneWay) }}
                />
                <RadioCard
                  icon={<MultiPathRetrieval className='shrink-0 mr-3 w-9 h-9 rounded-lg' />}
                  title={t('appDebug.datasetConfig.retrieveMultiWay.title')}
                  description={t('appDebug.datasetConfig.retrieveMultiWay.description')}
                  isChosen={type === RETRIEVE_TYPE.multiWay}
                  onChosen={() => { setType(RETRIEVE_TYPE.multiWay) }}
                />
              </div>
              {type === RETRIEVE_TYPE.multiWay && (
                <>
                  <div className='mt-6'>
                    <div className='leading-[32px] text-[13px] font-medium text-gray-900'>{t('common.modelProvider.rerankModel.key')}</div>
                    <div>
                      <ModelSelector
                        defaultModel={rerankModel && { provider: rerankModel.provider_name, model: rerankModel.model_name }}
                        onSelect={(v) => {
                          setTempDataSetConfigs(prev => ({
                            ...prev,
                            reranking_model: {
                              reranking_provider_name: v.provider,
                              reranking_model_name: v.model,
                            },
                          }))
                        }}
                        modelList={rerankModelList}
                      />
                    </div>
                  </div>
                  <div className='mt-4 space-y-4'>
                    <TopKItem
                      value={tempDataSetConfigs.top_k}
                      onChange={handleParamChange}
                      enable={true}
                    />
                    <ScoreThresholdItem
                      value={tempDataSetConfigs.score_threshold}
                      onChange={handleParamChange}
                      enable={tempDataSetConfigs.score_threshold_enabled}
                      hasSwitch={true}
                      onSwitchChange={handleSwitch}
                    />
                  </div>
                </>
              )}
              <div className='mt-6 flex justify-end'>
                <Button
                  className='mr-2 flex-shrink-0'
                  onClick={() => {
                    setOpen(false)
                  }}
                >
                  {t('common.operation.cancel')}
                </Button>
                <Button type='primary' className='flex-shrink-0' onClick={handleSave}>{t('common.operation.save')}</Button>
              </div>
            </Modal>
          )
        }
      </div>
    )
  }

  export default memo(ParamsConfig)