provider-context.tsx 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. 'use client'
  2. import { createContext, useContext } from 'use-context-selector'
  3. import useSWR from 'swr'
  4. import { fetchDefaultModal, fetchModelList, fetchSupportRetrievalMethods } from '@/service/common'
  5. import { ModelFeature, ModelType } from '@/app/components/header/account-setting/model-page/declarations'
  6. import type { BackendModel } from '@/app/components/header/account-setting/model-page/declarations'
  7. import type { RETRIEVE_METHOD } from '@/types/app'
  /**
   * Value exposed through `ProviderContext`:
   * - the workspace's model lists, one per model type;
   * - the default model for each type, together with the SWR `mutate` callback that refreshes it;
   * - the retrieval methods the backend supports.
   *
   * NOTE(review): `isRerankDefaultModelVaild` is misspelled ("Vaild"). The key is kept
   * as-is because consumers read the field by this exact name.
   */
  type ProviderContextValue = {
    textGenerationModelList: BackendModel[]
    embeddingsModelList: BackendModel[]
    speech2textModelList: BackendModel[]
    rerankModelList: BackendModel[]
    // Text-generation models whose `features` include the agent-thought feature.
    agentThoughtModelList: BackendModel[]
    // Refreshes the model list that belongs to `type`.
    updateModelList: (type: ModelType) => void
    textGenerationDefaultModel?: BackendModel
    mutateTextGenerationDefaultModel: () => void
    embeddingsDefaultModel?: BackendModel
    mutateEmbeddingsDefaultModel: () => void
    speech2textDefaultModel?: BackendModel
    mutateSpeech2textDefaultModel: () => void
    rerankDefaultModel?: BackendModel
    // True when the rerank default model is present in `rerankModelList`.
    isRerankDefaultModelVaild: boolean
    mutateRerankDefaultModel: () => void
    supportRetrievalMethods: RETRIEVE_METHOD[]
  }

  // Value seen by consumers that have no provider above them: empty lists,
  // no default models, and callbacks that do nothing.
  const defaultProviderContextValue: ProviderContextValue = {
    textGenerationModelList: [],
    embeddingsModelList: [],
    speech2textModelList: [],
    rerankModelList: [],
    agentThoughtModelList: [],
    updateModelList: () => {},
    textGenerationDefaultModel: undefined,
    mutateTextGenerationDefaultModel: () => {},
    speech2textDefaultModel: undefined,
    mutateSpeech2textDefaultModel: () => {},
    embeddingsDefaultModel: undefined,
    mutateEmbeddingsDefaultModel: () => {},
    rerankDefaultModel: undefined,
    isRerankDefaultModelVaild: false,
    mutateRerankDefaultModel: () => {},
    supportRetrievalMethods: [],
  }

  const ProviderContext = createContext<ProviderContextValue>(defaultProviderContextValue)
  /** Hook returning the current `ProviderContext` value. */
  export function useProviderContext() {
    return useContext(ProviderContext)
  }
  /** Props of `ProviderContextProvider`: the subtree rendered inside the provider. */
  type ProviderContextProviderProps = { children: React.ReactNode }
  /** Shared fallback for model lists that have not loaded yet (one reference for every render). */
  const EMPTY_MODEL_LIST: BackendModel[] = []
  /** Shared fallback for the supported retrieval methods before they have loaded. */
  const EMPTY_RETRIEVAL_METHODS: RETRIEVE_METHOD[] = []
  /** Agent-thought subsets already computed, keyed by the identity of their source list. */
  const agentThoughtModelCache = new WeakMap<BackendModel[], BackendModel[]>()

  /**
   * Returns the models in `models` whose `features` include `ModelFeature.agentThought`.
   *
   * The result is cached per input array. Passing the same array again returns the same
   * filtered array instead of allocating a new one on every render.
   * NOTE(review): assumes the SWR `data` arrays are never mutated in place; an in-place
   * mutation would leave the cached subset stale. If `react` gets imported here, a
   * `useMemo` inside the component is the more conventional alternative.
   */
  const selectAgentThoughtModels = (models: BackendModel[]): BackendModel[] => {
    const cached = agentThoughtModelCache.get(models)
    if (cached)
      return cached
    const filtered = models.filter(item => item.features?.includes(ModelFeature.agentThought))
    agentThoughtModelCache.set(models, filtered)
    return filtered
  }

  /**
   * Loads the workspace's data with SWR and provides it through `ProviderContext`. The data is:
   * - the default model for each model type;
   * - the per-type model lists;
   * - the supported retrieval methods.
   *
   * Anything that has not loaded yet is exposed as an empty array. These fallbacks are
   * module-level constants rather than fresh `[]` literals, so every render passes the
   * same reference. Consumers that pick a single field with `useContextSelector` are
   * re-rendered when the selected value changes by reference, so a new array on every
   * render would re-render them needlessly.
   */
  export const ProviderContextProvider = ({
    children,
  }: ProviderContextProviderProps) => {
    const { data: textGenerationDefaultModel, mutate: mutateTextGenerationDefaultModel } = useSWR('/workspaces/current/default-model?model_type=text-generation', fetchDefaultModal)
    const { data: embeddingsDefaultModel, mutate: mutateEmbeddingsDefaultModel } = useSWR('/workspaces/current/default-model?model_type=embeddings', fetchDefaultModal)
    const { data: speech2textDefaultModel, mutate: mutateSpeech2textDefaultModel } = useSWR('/workspaces/current/default-model?model_type=speech2text', fetchDefaultModal)
    const { data: rerankDefaultModel, mutate: mutateRerankDefaultModel } = useSWR('/workspaces/current/default-model?model_type=reranking', fetchDefaultModal)

    const fetchModelListUrlPrefix = '/workspaces/current/models/model-type/'
    const { data: textGenerationModelList, mutate: mutateTextGenerationModelList } = useSWR(`${fetchModelListUrlPrefix}${ModelType.textGeneration}`, fetchModelList)
    const { data: embeddingsModelList, mutate: mutateEmbeddingsModelList } = useSWR(`${fetchModelListUrlPrefix}${ModelType.embeddings}`, fetchModelList)
    const { data: speech2textModelList, mutate: mutateSpeech2textModelList } = useSWR(`${fetchModelListUrlPrefix}${ModelType.speech2text}`, fetchModelList)
    const { data: rerankModelList, mutate: mutateRerankModelList } = useSWR(`${fetchModelListUrlPrefix}${ModelType.reranking}`, fetchModelList)

    const { data: supportRetrievalMethods } = useSWR('/datasets/retrieval-setting', fetchSupportRetrievalMethods)

    const agentThoughtModelList = textGenerationModelList
      ? selectAgentThoughtModels(textGenerationModelList)
      : EMPTY_MODEL_LIST

    // The rerank default only counts as valid if it exists and the rerank model list
    // contains an entry with the same model name and the same provider.
    const isRerankDefaultModelVaild = !!rerankDefaultModel && !!rerankModelList?.some(
      item => item.model_name === rerankDefaultModel.model_name
        && item.model_provider.provider_name === rerankDefaultModel.model_provider.provider_name,
    )

    // Revalidate only the SWR entry that backs the requested model type; other types are ignored.
    const updateModelList = (type: ModelType) => {
      switch (type) {
        case ModelType.textGeneration:
          mutateTextGenerationModelList()
          break
        case ModelType.embeddings:
          mutateEmbeddingsModelList()
          break
        case ModelType.speech2text:
          mutateSpeech2textModelList()
          break
        case ModelType.reranking:
          mutateRerankModelList()
          break
      }
    }

    return (
      <ProviderContext.Provider value={{
        textGenerationModelList: textGenerationModelList ?? EMPTY_MODEL_LIST,
        embeddingsModelList: embeddingsModelList ?? EMPTY_MODEL_LIST,
        speech2textModelList: speech2textModelList ?? EMPTY_MODEL_LIST,
        rerankModelList: rerankModelList ?? EMPTY_MODEL_LIST,
        agentThoughtModelList,
        updateModelList,
        textGenerationDefaultModel,
        mutateTextGenerationDefaultModel,
        embeddingsDefaultModel,
        mutateEmbeddingsDefaultModel,
        speech2textDefaultModel,
        mutateSpeech2textDefaultModel,
        rerankDefaultModel,
        isRerankDefaultModelVaild,
        mutateRerankDefaultModel,
        supportRetrievalMethods: supportRetrievalMethods?.retrieval_method ?? EMPTY_RETRIEVAL_METHODS,
      }}>
        {children}
      </ProviderContext.Provider>
    )
  }
  // Expose the context object itself as the module's default export.
  export { ProviderContext as default }