provider-context.tsx 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. 'use client'
  2. import { createContext, useContext } from 'use-context-selector'
  3. import useSWR from 'swr'
  4. import { useEffect, useState } from 'react'
  5. import { fetchDefaultModal, fetchModelList, fetchSupportRetrievalMethods } from '@/service/common'
  6. import { ModelFeature, ModelType } from '@/app/components/header/account-setting/model-page/declarations'
  7. import type { BackendModel } from '@/app/components/header/account-setting/model-page/declarations'
  8. import type { RETRIEVE_METHOD } from '@/types/app'
  9. import { Plan, type UsagePlanInfo } from '@/app/components/billing/type'
  10. import { fetchCurrentPlanInfo } from '@/service/billing'
  11. import { parseCurrentPlan } from '@/app/components/billing/utils'
  12. import { defaultPlan } from '@/app/components/billing/config'
  /** Everything `ProviderContextProvider` publishes to its subtree. */
  type ProviderContextValue = {
    // Available models, one list per model type.
    textGenerationModelList: BackendModel[]
    embeddingsModelList: BackendModel[]
    speech2textModelList: BackendModel[]
    rerankModelList: BackendModel[]
    // Text-generation models that declare the agent-thought feature.
    agentThoughtModelList: BackendModel[]
    // Triggers a re-fetch of the model list for the given model type.
    updateModelList: (type: ModelType) => void
    // Workspace default model per type, each paired with a re-fetch trigger.
    textGenerationDefaultModel?: BackendModel
    mutateTextGenerationDefaultModel: () => void
    embeddingsDefaultModel?: BackendModel
    mutateEmbeddingsDefaultModel: () => void
    speech2textDefaultModel?: BackendModel
    mutateSpeech2textDefaultModel: () => void
    rerankDefaultModel?: BackendModel
    // (sic) "Vaild" — the key name is part of the public context shape, so it is kept.
    isRerankDefaultModelVaild: boolean
    mutateRerankDefaultModel: () => void
    supportRetrievalMethods: RETRIEVE_METHOD[]
    // Billing plan type with current usage and total quota.
    plan: {
      type: Plan
      usage: UsagePlanInfo
      total: UsagePlanInfo
    }
    isFetchedPlan: boolean
    enableBilling: boolean
  }

  /**
   * Fallback context value. React only hands this to consumers that are
   * rendered without a `ProviderContextProvider` above them.
   */
  const defaultProviderContextValue: ProviderContextValue = {
    textGenerationModelList: [],
    embeddingsModelList: [],
    speech2textModelList: [],
    rerankModelList: [],
    agentThoughtModelList: [],
    updateModelList: () => {},
    textGenerationDefaultModel: undefined,
    mutateTextGenerationDefaultModel: () => {},
    speech2textDefaultModel: undefined,
    mutateSpeech2textDefaultModel: () => {},
    embeddingsDefaultModel: undefined,
    mutateEmbeddingsDefaultModel: () => {},
    rerankDefaultModel: undefined,
    isRerankDefaultModelVaild: false,
    mutateRerankDefaultModel: () => {},
    supportRetrievalMethods: [],
    // NOTE(review): these usage/total numbers look like placeholders and differ from
    // the `defaultPlan` the provider initialises its state with — confirm intent.
    plan: {
      type: Plan.sandbox,
      usage: {
        vectorSpace: 32,
        buildApps: 12,
        teamMembers: 1,
      },
      total: {
        vectorSpace: 200,
        buildApps: 50,
        teamMembers: 1,
      },
    },
    isFetchedPlan: false,
    enableBilling: false,
  }

  const ProviderContext = createContext<ProviderContextValue>(defaultProviderContextValue)
  /**
   * Reads the value published by the nearest `ProviderContextProvider`,
   * or the context's fallback value when there is no provider above.
   */
  export function useProviderContext() {
    return useContext(ProviderContext)
  }
  /** Props of `ProviderContextProvider`: the subtree that will be able to read the context. */
  type ProviderContextProviderProps = { children: React.ReactNode }
  /** URL prefix of the per-type model list endpoint; a `ModelType` value is appended. */
  const MODEL_LIST_URL_PREFIX = '/workspaces/current/models/model-type/'

  /**
   * Fetches the workspace's model lists, default models, supported retrieval
   * methods and billing plan, and publishes them through `ProviderContext`.
   */
  export const ProviderContextProvider = ({
    children,
  }: ProviderContextProviderProps) => {
    // Default model for each model type; the paired `mutate*` lets consumers re-fetch it.
    const { data: textGenerationDefaultModel, mutate: mutateTextGenerationDefaultModel } = useSWR('/workspaces/current/default-model?model_type=text-generation', fetchDefaultModal)
    const { data: embeddingsDefaultModel, mutate: mutateEmbeddingsDefaultModel } = useSWR('/workspaces/current/default-model?model_type=embeddings', fetchDefaultModal)
    const { data: speech2textDefaultModel, mutate: mutateSpeech2textDefaultModel } = useSWR('/workspaces/current/default-model?model_type=speech2text', fetchDefaultModal)
    const { data: rerankDefaultModel, mutate: mutateRerankDefaultModel } = useSWR('/workspaces/current/default-model?model_type=reranking', fetchDefaultModal)

    // Available models for each model type.
    const { data: textGenerationModelList, mutate: mutateTextGenerationModelList } = useSWR(`${MODEL_LIST_URL_PREFIX}${ModelType.textGeneration}`, fetchModelList)
    const { data: embeddingsModelList, mutate: mutateEmbeddingsModelList } = useSWR(`${MODEL_LIST_URL_PREFIX}${ModelType.embeddings}`, fetchModelList)
    const { data: speech2textModelList, mutate: mutateSpeech2textModelList } = useSWR(`${MODEL_LIST_URL_PREFIX}${ModelType.speech2text}`, fetchModelList)
    const { data: rerankModelList, mutate: mutateRerankModelList } = useSWR(`${MODEL_LIST_URL_PREFIX}${ModelType.reranking}`, fetchModelList)

    const { data: supportRetrievalMethods } = useSWR('/datasets/retrieval-setting', fetchSupportRetrievalMethods)

    // Text-generation models that declare the agent-thought feature.
    const agentThoughtModelList = textGenerationModelList?.filter(item => item.features?.includes(ModelFeature.agentThought)) ?? []

    // The default rerank model only counts as valid while it is still present in the
    // rerank model list (same model name and same provider name).
    const isRerankDefaultModelVaild = !!rerankModelList?.some(
      item => item.model_name === rerankDefaultModel?.model_name
        && item.model_provider.provider_name === rerankDefaultModel?.model_provider.provider_name,
    )

    // Re-fetch (SWR `mutate`) the model list of the given type; other types are ignored.
    const updateModelList = (type: ModelType) => {
      switch (type) {
        case ModelType.textGeneration:
          mutateTextGenerationModelList()
          break
        case ModelType.embeddings:
          mutateEmbeddingsModelList()
          break
        case ModelType.speech2text:
          mutateSpeech2textModelList()
          break
        case ModelType.reranking:
          mutateRerankModelList()
          break
      }
    }

    const [plan, setPlan] = useState(defaultPlan)
    const [isFetchedPlan, setIsFetchedPlan] = useState(false)
    const [enableBilling, setEnableBilling] = useState(true)

    // Load the current billing plan once on mount. A response that arrives after the
    // effect is cleaned up (unmount, or the discarded first run in StrictMode) is ignored,
    // and a failed request is logged instead of surfacing as an unhandled rejection;
    // in that case the initial state above is kept.
    useEffect(() => {
      let cancelled = false
      const loadPlanInfo = async () => {
        try {
          const data = await fetchCurrentPlanInfo()
          if (cancelled)
            return
          const enabled = data.enabled
          setEnableBilling(enabled)
          // `isFetchedPlan` is only set when billing is enabled and a plan was parsed.
          if (enabled) {
            setPlan(parseCurrentPlan(data))
            setIsFetchedPlan(true)
          }
        }
        catch (e) {
          console.error('Failed to fetch current plan info:', e)
        }
      }
      void loadPlanInfo()
      return () => {
        cancelled = true
      }
    }, [])

    return (
      <ProviderContext.Provider value={{
        textGenerationModelList: textGenerationModelList || [],
        embeddingsModelList: embeddingsModelList || [],
        speech2textModelList: speech2textModelList || [],
        rerankModelList: rerankModelList || [],
        agentThoughtModelList,
        updateModelList,
        textGenerationDefaultModel,
        mutateTextGenerationDefaultModel,
        embeddingsDefaultModel,
        mutateEmbeddingsDefaultModel,
        speech2textDefaultModel,
        mutateSpeech2textDefaultModel,
        rerankDefaultModel,
        isRerankDefaultModelVaild,
        mutateRerankDefaultModel,
        supportRetrievalMethods: supportRetrievalMethods?.retrieval_method || [],
        plan,
        isFetchedPlan,
        enableBilling,
      }}>
        {children}
      </ProviderContext.Provider>
    )
  }

  export default ProviderContext