hooks.ts 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353
  1. import {
  2. useCallback,
  3. useEffect,
  4. useMemo,
  5. useState,
  6. } from 'react'
  7. import useSWR, { useSWRConfig } from 'swr'
  8. import { useContext } from 'use-context-selector'
  9. import type {
  10. CustomConfigurationModelFixedFields,
  11. DefaultModel,
  12. DefaultModelResponse,
  13. Model,
  14. ModelProvider,
  15. ModelTypeEnum,
  16. } from './declarations'
  17. import {
  18. ConfigurationMethodEnum,
  19. CustomConfigurationStatusEnum,
  20. ModelStatusEnum,
  21. } from './declarations'
  22. import I18n from '@/context/i18n'
  23. import {
  24. fetchDefaultModal,
  25. fetchModelList,
  26. fetchModelProviderCredentials,
  27. fetchModelProviders,
  28. getPayUrl,
  29. } from '@/service/common'
  30. import { useProviderContext } from '@/context/provider-context'
  31. import {
  32. useMarketplacePlugins,
  33. } from '@/app/components/plugins/marketplace/hooks'
  34. import type { Plugin } from '@/app/components/plugins/types'
  35. import { PluginType } from '@/app/components/plugins/types'
  36. import { getMarketplacePluginsByCollectionId } from '@/app/components/plugins/marketplace/utils'
  37. import { useModalContextSelector } from '@/context/modal-context'
  38. import { useEventEmitterContextContext } from '@/context/event-emitter'
  39. import { UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST } from './provider-added-card'
  /**
   * Signature of the system-default-model hook: takes the default-model response and the
   * available model list, and yields a `[selection, setSelection]` tuple.
   */
  type UseDefaultModelAndModelList = (
    defaultModel: DefaultModelResponse | undefined,
    modelList: Model[],
  ) => [DefaultModel | undefined, (model: DefaultModel) => void]

  /**
   * Locally editable selection of the system default model.
   *
   * The selection starts as `defaultModel` matched against `modelList` (`undefined` when the
   * provider or the model cannot be found there). It is reset to that resolved value
   * whenever the resolved value changes. In between, the returned setter can override it.
   */
  export const useSystemDefaultModelAndModelList: UseDefaultModelAndModelList = (
    defaultModel,
    modelList,
  ) => {
    const resolvedDefault = useMemo(() => {
      const providerEntry = modelList.find(entry => entry.provider === defaultModel?.provider.provider)
      const modelEntry = providerEntry?.models.find(entry => entry.model === defaultModel?.model)
      if (!providerEntry || !modelEntry)
        return undefined
      return {
        model: modelEntry.model,
        provider: providerEntry.provider,
      }
    }, [defaultModel, modelList])

    const [selection, setSelection] = useState<DefaultModel | undefined>(resolvedDefault)

    const handleDefaultModelChange = useCallback((model: DefaultModel) => {
      setSelection(model)
    }, [])

    // Drop any local override once the resolved default changes.
    useEffect(() => {
      setSelection(resolvedDefault)
    }, [resolvedDefault])

    return [selection, handleDefaultModelChange]
  }
  /**
   * Current UI locale with every `-` replaced by `_` (e.g. `en-US` -> `en_US`).
   *
   * NOTE(review): presumably this is the key format used for localized labels coming from the
   * backend — confirm against callers.
   */
  export const useLanguage = () => {
    const { locale } = useContext(I18n)
    // `String#replace` with a string pattern only rewrites the first match. A global regex
    // also converts multi-subtag locales such as `zh-Hant-TW` completely.
    return locale.replace(/-/g, '_')
  }
  /**
   * Load stored credentials and the load-balancing config, either for a provider
   * (predefined models) or for one custom model (customizable models).
   *
   * @param provider - Provider identifier, interpolated into the request path as-is.
   * @param configurationMethod - Selects which of the two credential endpoints is queried.
   * @param configured - Predefined models only: the request is made only when this is truthy.
   * @param currentCustomConfigurationModelFixedFields - Customizable models only: the model
   *   name/type whose credentials are fetched. No request is made when it is absent.
   * @returns `credentials` (for non-predefined methods, the stored credentials merged with
   *   the fixed fields), `loadBalancing` from the response matching the configuration method,
   *   and `mutate`, which calls `mutate()` on both SWR requests.
   */
  export const useProviderCredentialsAndLoadBalancing = (
    provider: string,
    configurationMethod: ConfigurationMethodEnum,
    configured?: boolean,
    currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields,
  ) => {
    const isPredefined = configurationMethod === ConfigurationMethodEnum.predefinedModel
    const isCustomizable = configurationMethod === ConfigurationMethodEnum.customizableModel

    const { data: predefinedFormSchemasValue, mutate: mutatePredefined } = useSWR(
      (isPredefined && configured)
        ? `/workspaces/current/model-providers/${provider}/credentials`
        : null,
      fetchModelProviderCredentials,
    )

    // Model names are not guaranteed to be URL-safe. Characters such as `&`, `#`, `+` or `%`
    // would corrupt the query string, so both values are percent-encoded. Names made only of
    // alphanumerics and `-_.!~*'()` produce exactly the same key as before.
    const { data: customFormSchemasValue, mutate: mutateCustomized } = useSWR(
      (isCustomizable && currentCustomConfigurationModelFixedFields)
        ? `/workspaces/current/model-providers/${provider}/models/credentials?model=${encodeURIComponent(currentCustomConfigurationModelFixedFields.__model_name)}&model_type=${encodeURIComponent(currentCustomConfigurationModelFixedFields.__model_type)}`
        : null,
      fetchModelProviderCredentials,
    )

    const credentials = useMemo(() => {
      if (isPredefined)
        return predefinedFormSchemasValue?.credentials
      if (!customFormSchemasValue?.credentials)
        return undefined
      // Spread order matters: the fixed fields win over same-named keys in the stored credentials.
      return {
        ...customFormSchemasValue.credentials,
        ...currentCustomConfigurationModelFixedFields,
      }
    }, [
      isPredefined,
      currentCustomConfigurationModelFixedFields,
      customFormSchemasValue?.credentials,
      predefinedFormSchemasValue?.credentials,
    ])

    // Revalidate both credential requests.
    const mutate = useCallback(() => {
      mutatePredefined()
      mutateCustomized()
    }, [mutateCustomized, mutatePredefined])

    return {
      credentials,
      loadBalancing: (isPredefined ? predefinedFormSchemasValue : customFormSchemasValue)?.load_balancing,
      mutate,
    }
  }
  /**
   * SWR-backed list of models for one model type.
   * `data` falls back to an empty array whenever the response's `data` field is falsy
   * (for example, before the first response arrives).
   */
  export const useModelList = (type: ModelTypeEnum) => {
    const swrKey = `/workspaces/current/models/model-types/${type}`
    const response = useSWR(swrKey, fetchModelList)
    return {
      data: response.data?.data || [],
      mutate: response.mutate,
      isLoading: response.isLoading,
    }
  }
  /**
   * SWR-backed workspace default model for one model type.
   * `data` is the response's `data` field, or `undefined` while no response is available.
   */
  export const useDefaultModel = (type: ModelTypeEnum) => {
    const swrKey = `/workspaces/current/default-model?model_type=${type}`
    const response = useSWR(swrKey, fetchDefaultModal)
    return {
      data: response.data?.data,
      mutate: response.mutate,
      isLoading: response.isLoading,
    }
  }
  /**
   * Find the provider entry and the model entry in `modelList` that `defaultModel` refers to.
   * Despite the `use` prefix this calls no hooks; it is a pure lookup. Either result is
   * `undefined` when no match is found.
   */
  export const useCurrentProviderAndModel = (modelList: Model[], defaultModel?: DefaultModel) => {
    const currentProvider = modelList.find(({ provider }) => provider === defaultModel?.provider)
    const currentModel = currentProvider?.models.find(({ model }) => model === defaultModel?.model)
    return { currentProvider, currentModel }
  }
  /**
   * Text-generation models from the provider context, their `active`-status subset, and the
   * provider/model entries matching `defaultModel`. The lookup runs against the full list,
   * not only the active subset.
   */
  export const useTextGenerationCurrentProviderAndModelAndModelList = (defaultModel?: DefaultModel) => {
    const { textGenerationModelList } = useProviderContext()
    const activeTextGenerationModelList = textGenerationModelList.filter(
      ({ status }) => status === ModelStatusEnum.active,
    )
    const current = useCurrentProviderAndModel(textGenerationModelList, defaultModel)
    return {
      currentProvider: current.currentProvider,
      currentModel: current.currentModel,
      textGenerationModelList,
      activeTextGenerationModelList,
    }
  }
  /** Model list and default model for one model type, each fetched by its own hook. */
  export const useModelListAndDefaultModel = (type: ModelTypeEnum) => {
    const modelList = useModelList(type).data
    const defaultModel = useDefaultModel(type).data
    return { modelList, defaultModel }
  }
  /**
   * Model list and default model for one model type, plus the provider/model entries of that
   * default inside the list.
   */
  export const useModelListAndDefaultModelAndCurrentProviderAndModel = (type: ModelTypeEnum) => {
    const { modelList, defaultModel } = useModelListAndDefaultModel(type)
    // Missing pieces of the default are looked up as '' rather than undefined.
    const lookupTarget = {
      provider: defaultModel?.provider.provider || '',
      model: defaultModel?.model || '',
    }
    const { currentProvider, currentModel } = useCurrentProviderAndModel(modelList, lookupTarget)
    return { modelList, defaultModel, currentProvider, currentModel }
  }
  /**
   * Returns a memoized callback that calls the global SWR `mutate` on the model-list key
   * for the given model type.
   */
  export const useUpdateModelList = () => {
    const { mutate: revalidate } = useSWRConfig()
    return useCallback((type: ModelTypeEnum) => {
      revalidate(`/workspaces/current/models/model-types/${type}`)
    }, [revalidate])
  }
  /**
   * Returns an async handler that requests the Anthropic checkout URL and navigates the
   * window to it. Calls made while a request is pending are ignored. Errors from
   * `getPayUrl` propagate to the caller after the loading flag is cleared.
   */
  export const useAnthropicBuyQuota = () => {
    const [loading, setLoading] = useState(false)

    async function handleGetPayUrl() {
      if (loading)
        return
      setLoading(true)
      try {
        const { url } = await getPayUrl('/workspaces/current/model-providers/anthropic/checkout-url')
        window.location.href = url
      }
      finally {
        setLoading(false)
      }
    }

    return handleGetPayUrl
  }
  /**
   * SWR-backed list of the workspace's model providers.
   * `data` falls back to an empty array whenever the response's `data` field is falsy.
   */
  export const useModelProviders = () => {
    const response = useSWR('/workspaces/current/model-providers', fetchModelProviders)
    return {
      data: response.data?.data || [],
      mutate: response.mutate,
      isLoading: response.isLoading,
    }
  }
  /** Returns a memoized callback that calls the global SWR `mutate` on the model-providers key. */
  export const useUpdateModelProviders = () => {
    const { mutate: revalidate } = useSWRConfig()
    return useCallback(() => {
      revalidate('/workspaces/current/model-providers')
    }, [revalidate])
  }
  /**
   * Marketplace model plugins: the `__model-settings-pinned-models` collection followed by
   * marketplace query results.
   *
   * `exclude` is built from each provider id with its last `/`-separated segment dropped
   * (e.g. `a/b/c` -> `a/b`). NOTE(review): presumably this yields the plugin id; confirm.
   * It is sent with every marketplace query and also filters the pinned collection.
   *
   * With a non-empty `searchText` a debounced search is issued. Otherwise up to 1000 model
   * plugins are listed, sorted by install count (descending).
   *
   * Result order: pinned plugins (minus excluded ids), then query results that are not
   * bundles and whose `plugin_id` is not already listed.
   */
  export const useMarketplaceAllPlugins = (providers: ModelProvider[], searchText: string) => {
    const exclude = useMemo(() => {
      return providers.map(provider => provider.provider.replace(/(.+)\/([^/]+)$/, '$1'))
    }, [providers])
    const [collectionPlugins, setCollectionPlugins] = useState<Plugin[]>([])
    const {
      plugins,
      queryPlugins,
      queryPluginsWithDebounced,
      isLoading,
    } = useMarketplacePlugins()

    // Fetch the pinned collection once. The `cancelled` flag prevents a state update after
    // unmount (or after an effect re-run). A failure is logged instead of becoming an
    // unhandled rejection; the list then contains only query results.
    useEffect(() => {
      let cancelled = false
      const loadPinnedPlugins = async () => {
        try {
          const pinned = await getMarketplacePluginsByCollectionId('__model-settings-pinned-models')
          if (!cancelled)
            setCollectionPlugins(pinned)
        }
        catch (e) {
          console.error('Failed to load pinned model plugins', e)
        }
      }
      loadPinnedPlugins()
      return () => {
        cancelled = true
      }
    }, [])

    useEffect(() => {
      if (searchText) {
        queryPluginsWithDebounced({
          query: searchText,
          category: PluginType.model,
          exclude,
          type: 'plugin',
          sortBy: 'install_count',
          sortOrder: 'DESC',
        })
      }
      else {
        queryPlugins({
          query: '',
          category: PluginType.model,
          type: 'plugin',
          pageSize: 1000,
          exclude,
          sortBy: 'install_count',
          sortOrder: 'DESC',
        })
      }
    }, [queryPlugins, queryPluginsWithDebounced, searchText, exclude])

    // Set lookups keep the merge linear. Query results can number up to 1000, and the
    // previous `find`-per-item dedup was quadratic.
    const allPlugins = useMemo(() => {
      const excludedIds = new Set(exclude)
      const merged = collectionPlugins.filter(plugin => !excludedIds.has(plugin.plugin_id))
      const seenIds = new Set(merged.map(plugin => plugin.plugin_id))
      for (const plugin of plugins ?? []) {
        if (plugin.type === 'bundle' || seenIds.has(plugin.plugin_id))
          continue
        seenIds.add(plugin.plugin_id)
        merged.push(plugin)
      }
      return merged
    }, [plugins, collectionPlugins, exclude])

    return {
      plugins: allPlugins,
      isLoading,
    }
  }
  /**
   * Returns a memoized function that opens the model configuration modal for a provider.
   *
   * When the modal's save callback runs, it:
   * - revalidates the model-provider list;
   * - revalidates the model list of every type in `provider.supported_model_types`;
   * - for the customizable-model method, when `provider.custom_configuration.status` is
   *   active, emits `UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST` with the provider id and
   *   revalidates the model list of the edited model's `__model_type` (if given).
   */
  export const useModelModalHandler = () => {
    const setShowModelModal = useModalContextSelector(state => state.setShowModelModal)
    const updateModelProviders = useUpdateModelProviders()
    const updateModelList = useUpdateModelList()
    const { eventEmitter } = useEventEmitterContextContext()

    // Parameter renamed from `CustomConfigurationModelFixedFields` so it no longer reuses the
    // imported type's name. Callers pass it positionally, so this is not a breaking change.
    return useCallback((
      provider: ModelProvider,
      configurationMethod: ConfigurationMethodEnum,
      customConfigurationModelFixedFields?: CustomConfigurationModelFixedFields,
    ) => {
      setShowModelModal({
        payload: {
          currentProvider: provider,
          currentConfigurationMethod: configurationMethod,
          currentCustomConfigurationModelFixedFields: customConfigurationModelFixedFields,
        },
        onSaveCallback: () => {
          updateModelProviders()
          provider.supported_model_types.forEach((type) => {
            updateModelList(type)
          })
          if (configurationMethod === ConfigurationMethodEnum.customizableModel
            && provider.custom_configuration.status === CustomConfigurationStatusEnum.active) {
            // NOTE(review): the emitter's payload type is not visible from here, so the original
            // cast is kept. Confirm whether it can be typed precisely.
            eventEmitter?.emit({
              type: UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST,
              payload: provider.provider,
            } as any)
            if (customConfigurationModelFixedFields?.__model_type)
              updateModelList(customConfigurationModelFixedFields.__model_type)
          }
        },
      })
    }, [eventEmitter, setShowModelModal, updateModelList, updateModelProviders])
  }