use-workflow-run.ts 18 KB


  1. import { useCallback } from 'react'
  2. import {
  3. getIncomers,
  4. useReactFlow,
  5. useStoreApi,
  6. } from 'reactflow'
  7. import produce from 'immer'
  8. import { v4 as uuidV4 } from 'uuid'
  9. import { usePathname } from 'next/navigation'
  10. import { useWorkflowStore } from '../store'
  11. import { useNodesSyncDraft } from '../hooks'
  12. import type { Node } from '../types'
  13. import {
  14. NodeRunningStatus,
  15. WorkflowRunningStatus,
  16. } from '../types'
  17. import { useWorkflowUpdate } from './use-workflow-interactions'
  18. import { useStore as useAppStore } from '@/app/components/app/store'
  19. import type { IOtherOptions } from '@/service/base'
  20. import { ssePost } from '@/service/base'
  21. import {
  22. fetchPublishedWorkflow,
  23. stopWorkflowRun,
  24. } from '@/service/workflow'
  25. import { useFeaturesStore } from '@/app/components/base/features/hooks'
  26. import { AudioPlayerManager } from '@/app/components/base/audio-btn/audio.player.manager'
  /**
   * Hook bundling the draft-run helpers for the workflow canvas.
   *
   * Returns:
   * - `handleBackupDraft` / `handleLoadBackupDraft`: snapshot and restore the
   *   canvas (nodes, edges, viewport, features, environment variables).
   * - `handleRun`: syncs the draft, starts a draft run over SSE and mirrors the
   *   streamed events onto the canvas and into `workflowRunningData`.
   * - `handleStopRun`: calls `stopWorkflowRun` for a running task id.
   * - `handleRestoreFromPublishedWorkflow`: replaces the canvas, features and
   *   environment variables with the published workflow.
   */
  export const useWorkflowRun = () => {
    const store = useStoreApi()
    const workflowStore = useWorkflowStore()
    const reactflow = useReactFlow()
    const featuresStore = useFeaturesStore()
    const { doSyncWorkflowDraft } = useNodesSyncDraft()
    const { handleUpdateWorkflowCanvas } = useWorkflowUpdate()
    const pathname = usePathname()

    /**
     * Snapshot the current canvas into `backupDraft` and then call
     * `doSyncWorkflowDraft()`. Does nothing when a backup already exists, so
     * the first snapshot is kept until `handleLoadBackupDraft` consumes it.
     */
    const handleBackupDraft = useCallback(() => {
      const {
        getNodes,
        edges,
      } = store.getState()
      const { getViewport } = reactflow
      const {
        backupDraft,
        setBackupDraft,
        environmentVariables,
      } = workflowStore.getState()
      const { features } = featuresStore!.getState()

      if (!backupDraft) {
        setBackupDraft({
          nodes: getNodes(),
          edges,
          viewport: getViewport(),
          features,
          environmentVariables,
        })
        doSyncWorkflowDraft()
      }
    }, [reactflow, workflowStore, store, featuresStore, doSyncWorkflowDraft])

    /**
     * Restore the canvas, environment variables and features from
     * `backupDraft`, then clear the backup. No-op when there is no backup.
     */
    const handleLoadBackupDraft = useCallback(() => {
      const {
        backupDraft,
        setBackupDraft,
        setEnvironmentVariables,
      } = workflowStore.getState()

      if (backupDraft) {
        const {
          nodes,
          edges,
          viewport,
          features,
          environmentVariables,
        } = backupDraft
        handleUpdateWorkflowCanvas({
          nodes,
          edges,
          viewport,
        })
        setEnvironmentVariables(environmentVariables)
        featuresStore!.setState({ features })
        setBackupDraft(undefined)
      }
    }, [handleUpdateWorkflowCanvas, workflowStore, featuresStore])

    /**
     * Start a draft run of the current app and stream its progress.
     *
     * Before the request it does four things:
     * - clears node selection and any previous `_runningStatus`;
     * - awaits `doSyncWorkflowDraft()`;
     * - clears `historyWorkflowData`;
     * - resets `workflowRunningData` to a fresh running state.
     *
     * @param params   Request body posted to the run endpoint. `params.token`
     *                 or `params.appId` also select the text-to-audio endpoint.
     * @param callback Optional `ssePost` options. The workflow/node/iteration
     *                 lifecycle handlers and `onError` are invoked after the
     *                 built-in canvas updates. Every other entry is passed
     *                 straight through to `ssePost` and overrides the built-in
     *                 entry of the same name (e.g. `onTextChunk`).
     */
    const handleRun = useCallback(async (
      params: any,
      callback?: IOtherOptions,
    ) => {
      const {
        getNodes,
        setNodes,
      } = store.getState()
      // Reset selection and status left over from a previous run.
      const newNodes = produce(getNodes(), (draft) => {
        draft.forEach((node) => {
          node.data.selected = false
          node.data._runningStatus = undefined
        })
      })
      setNodes(newNodes)
      await doSyncWorkflowDraft()

      const {
        onWorkflowStarted,
        onWorkflowFinished,
        onNodeStarted,
        onNodeFinished,
        onIterationStart,
        onIterationNext,
        onIterationFinish,
        onError,
        ...restCallback
      } = callback || {}
      workflowStore.setState({ historyWorkflowData: undefined })
      const appDetail = useAppStore.getState().appDetail

      // Canvas size, used to centre the viewport on the running node. Fall back
      // to 0 instead of throwing when the container element is not mounted.
      const workflowContainer = document.getElementById('workflow-container')
      const clientWidth = workflowContainer?.clientWidth ?? 0
      const clientHeight = workflowContainer?.clientHeight ?? 0

      // The draft-run endpoint depends on the app mode. For any other mode
      // `url` stays empty.
      let url = ''
      if (appDetail?.mode === 'advanced-chat')
        url = `/apps/${appDetail.id}/advanced-chat/workflows/draft/run`
      if (appDetail?.mode === 'workflow')
        url = `/apps/${appDetail.id}/workflows/draft/run`

      // Id of the last top-level node (or iteration node) that finished.
      // onIterationStart uses it to mark the edge leading into the iteration.
      let prevNodeId = ''

      const {
        setWorkflowRunningData,
      } = workflowStore.getState()
      setWorkflowRunningData({
        result: {
          status: WorkflowRunningStatus.Running,
        },
        tracing: [],
        resultText: '',
      })

      // Text-to-audio endpoint:
      // - public when a token is present;
      // - otherwise app-scoped, using the installed-app route on explore pages.
      let ttsUrl = ''
      let ttsIsPublic = false
      if (params.token) {
        ttsUrl = '/text-to-audio'
        ttsIsPublic = true
      }
      else if (params.appId) {
        if (pathname.search('explore/installed') > -1)
          ttsUrl = `/installed-apps/${params.appId}/text-to-audio`
        else
          ttsUrl = `/apps/${params.appId}/text-to-audio`
      }
      const player = AudioPlayerManager.getInstance().getAudioPlayer(ttsUrl, ttsIsPublic, uuidV4(), 'none', 'none', (_: any): any => {})

      // Pan the canvas so `node` sits centred in the visible area. The extra
      // 400px presumably reserves room for a side panel (TODO confirm). Nodes
      // that have a parent (e.g. children of an iteration) are not focused.
      const centerViewportOnNode = (node: Node, zoom: number) => {
        if (node.parentId)
          return
        const { setViewport } = reactflow
        const { position } = node
        setViewport({
          x: (clientWidth - 400 - node.width! * zoom) / 2 - position.x * zoom,
          y: (clientHeight - node.height! * zoom) / 2 - position.y * zoom,
          zoom,
        })
      }

      ssePost(
        url,
        {
          body: params,
        },
        {
          onWorkflowStarted: (params) => {
            const { task_id, data } = params
            const {
              workflowRunningData,
              setWorkflowRunningData,
            } = workflowStore.getState()
            const {
              edges,
              setEdges,
            } = store.getState()
            setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
              draft.task_id = task_id
              draft.result = {
                ...draft?.result,
                ...data,
                status: WorkflowRunningStatus.Running,
              }
            }))
            // Clear the "ran" highlight on every edge.
            const newEdges = produce(edges, (draft) => {
              draft.forEach((edge) => {
                edge.data = {
                  ...edge.data,
                  _run: false,
                }
              })
            })
            setEdges(newEdges)
            if (onWorkflowStarted)
              onWorkflowStarted(params)
          },
          onWorkflowFinished: (params) => {
            const { data } = params
            const {
              workflowRunningData,
              setWorkflowRunningData,
            } = workflowStore.getState()
            // A single string output is shown directly in the result tab.
            const isStringOutput = data.outputs && Object.keys(data.outputs).length === 1 && typeof data.outputs[Object.keys(data.outputs)[0]] === 'string'
            setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
              draft.result = {
                ...draft.result,
                ...data,
              } as any
              if (isStringOutput) {
                draft.resultTabActive = true
                draft.resultText = data.outputs[Object.keys(data.outputs)[0]]
              }
            }))
            prevNodeId = ''
            if (onWorkflowFinished)
              onWorkflowFinished(params)
          },
          onError: (params) => {
            const {
              workflowRunningData,
              setWorkflowRunningData,
            } = workflowStore.getState()
            setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
              draft.result = {
                ...draft.result,
                status: WorkflowRunningStatus.Failed,
              }
            }))
            if (onError)
              onError(params)
          },
          onNodeStarted: (params) => {
            const { data } = params
            const {
              workflowRunningData,
              setWorkflowRunningData,
            } = workflowStore.getState()
            const {
              getNodes,
              setNodes,
              edges,
              setEdges,
              transform,
            } = store.getState()
            const nodes = getNodes()
            const node = nodes.find(node => node.id === data.node_id)
            if (node?.parentId) {
              // Child of an iteration: record it under the parent's trace,
              // in the round it belongs to (or the latest round).
              setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
                const tracing = draft.tracing!
                const iterations = tracing.find(trace => trace.node_id === node?.parentId)
                const currIteration = iterations?.details![node.data.iteration_index] || iterations?.details![iterations.details!.length - 1]
                currIteration?.push({
                  ...data,
                  status: NodeRunningStatus.Running,
                } as any)
              }))
            }
            else {
              setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
                draft.tracing!.push({
                  ...data,
                  status: NodeRunningStatus.Running,
                } as any)
              }))
              const currentNodeIndex = nodes.findIndex(node => node.id === data.node_id)
              // Skip canvas updates when the node is not on the canvas, instead
              // of throwing inside the stream handler.
              if (currentNodeIndex > -1) {
                centerViewportOnNode(nodes[currentNodeIndex], transform[2])
                const newNodes = produce(nodes, (draft) => {
                  draft[currentNodeIndex].data._runningStatus = NodeRunningStatus.Running
                })
                setNodes(newNodes)
                // Highlight edges coming from predecessors that already succeeded.
                const incomeNodesId = getIncomers({ id: data.node_id } as Node, newNodes, edges).filter(node => node.data._runningStatus === NodeRunningStatus.Succeeded).map(node => node.id)
                const newEdges = produce(edges, (draft) => {
                  draft.forEach((edge) => {
                    if (edge.target === data.node_id && incomeNodesId.includes(edge.source))
                      edge.data = { ...edge.data, _run: true } as any
                  })
                })
                setEdges(newEdges)
              }
            }
            if (onNodeStarted)
              onNodeStarted(params)
          },
          onNodeFinished: (params) => {
            const { data } = params
            const {
              workflowRunningData,
              setWorkflowRunningData,
            } = workflowStore.getState()
            const {
              getNodes,
              setNodes,
            } = store.getState()
            const nodes = getNodes()
            const nodeParentId = nodes.find(node => node.id === data.node_id)?.parentId
            if (nodeParentId) {
              setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
                const tracing = draft.tracing!
                // The parent is the iteration node. Its `details` holds one
                // array of node traces per iteration round.
                const iterations = tracing.find(trace => trace.node_id === nodeParentId)
                if (iterations && iterations.details) {
                  const iterationIndex = data.execution_metadata?.iteration_index || 0
                  if (!iterations.details[iterationIndex])
                    iterations.details[iterationIndex] = []
                  const currIteration = iterations.details[iterationIndex]
                  const nodeIndex = currIteration.findIndex(node =>
                    node.node_id === data.node_id && (
                      node.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || node.parallel_id === data.execution_metadata?.parallel_id),
                  )
                  // Apply every finish event, not only successful ones, so a
                  // failed node does not stay "Running" in the iteration trace.
                  if (nodeIndex !== -1) {
                    currIteration[nodeIndex] = {
                      ...currIteration[nodeIndex],
                      ...data,
                    } as any
                  }
                  else {
                    currIteration.push({
                      ...data,
                    } as any)
                  }
                }
              }))
            }
            else {
              setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
                // Match on parallel_id too when the trace has one, so parallel
                // runs of the same node are kept apart.
                const currentIndex = draft.tracing!.findIndex((trace) => {
                  if (!trace.execution_metadata?.parallel_id)
                    return trace.node_id === data.node_id
                  return trace.node_id === data.node_id && trace.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id
                })
                if (currentIndex > -1 && draft.tracing) {
                  // Replace the trace with the finish payload, keeping `extras`.
                  draft.tracing[currentIndex] = {
                    ...(draft.tracing[currentIndex].extras
                      ? { extras: draft.tracing[currentIndex].extras }
                      : {}),
                    ...data,
                  } as any
                }
              }))
              const newNodes = produce(nodes, (draft) => {
                const currentNode = draft.find(node => node.id === data.node_id)
                if (currentNode)
                  currentNode.data._runningStatus = data.status as any
              })
              setNodes(newNodes)
              prevNodeId = data.node_id
            }
            if (onNodeFinished)
              onNodeFinished(params)
          },
          onIterationStart: (params) => {
            const { data } = params
            const {
              workflowRunningData,
              setWorkflowRunningData,
            } = workflowStore.getState()
            const {
              getNodes,
              setNodes,
              edges,
              setEdges,
              transform,
            } = store.getState()
            const nodes = getNodes()
            setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
              draft.tracing!.push({
                ...data,
                status: NodeRunningStatus.Running,
                details: [],
              } as any)
            }))
            const currentNodeIndex = nodes.findIndex(node => node.id === data.node_id)
            if (currentNodeIndex > -1) {
              centerViewportOnNode(nodes[currentNodeIndex], transform[2])
              const newNodes = produce(nodes, (draft) => {
                draft[currentNodeIndex].data._runningStatus = NodeRunningStatus.Running
                draft[currentNodeIndex].data._iterationLength = data.metadata.iterator_length
              })
              setNodes(newNodes)
            }
            // Highlight the edge from the previously finished node into this iteration.
            const newEdges = produce(edges, (draft) => {
              const edge = draft.find(edge => edge.target === data.node_id && edge.source === prevNodeId)
              if (edge)
                edge.data = { ...edge.data, _run: true } as any
            })
            setEdges(newEdges)
            if (onIterationStart)
              onIterationStart(params)
          },
          onIterationNext: (params) => {
            const {
              workflowRunningData,
              setWorkflowRunningData,
            } = workflowStore.getState()
            const { data } = params
            const {
              getNodes,
              setNodes,
            } = store.getState()
            // Open a new round in the trace, without exceeding iterator_length.
            setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
              const iteration = draft.tracing!.find(trace => trace.node_id === data.node_id)
              if (iteration) {
                if (iteration.details!.length >= iteration.metadata.iterator_length!)
                  return
              }
              iteration?.details!.push([])
            }))
            const nodes = getNodes()
            const newNodes = produce(nodes, (draft) => {
              const currentNode = draft.find(node => node.id === data.node_id)
              if (currentNode)
                currentNode.data._iterationIndex = data.index > 0 ? data.index : 1
            })
            setNodes(newNodes)
            if (onIterationNext)
              onIterationNext(params)
          },
          onIterationFinish: (params) => {
            const { data } = params
            const {
              workflowRunningData,
              setWorkflowRunningData,
            } = workflowStore.getState()
            const {
              getNodes,
              setNodes,
            } = store.getState()
            const nodes = getNodes()
            setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
              const tracing = draft.tracing!
              const currIterationNode = tracing.find(trace => trace.node_id === data.node_id)
              if (currIterationNode) {
                // NOTE(review): the trace is always marked Succeeded here, while
                // the canvas node below takes `data.status`. Confirm whether
                // failed iterations should be reflected in the trace too.
                Object.assign(currIterationNode, {
                  ...data,
                  status: NodeRunningStatus.Succeeded,
                })
              }
            }))
            const newNodes = produce(nodes, (draft) => {
              const currentNode = draft.find(node => node.id === data.node_id)
              if (currentNode)
                currentNode.data._runningStatus = data.status
            })
            setNodes(newNodes)
            prevNodeId = data.node_id
            if (onIterationFinish)
              onIterationFinish(params)
          },
          // Parallel-branch events are currently not reflected on the canvas.
          onParallelBranchStarted: () => {},
          onParallelBranchFinished: () => {},
          onTextChunk: (params) => {
            const { data: { text } } = params
            const {
              workflowRunningData,
              setWorkflowRunningData,
            } = workflowStore.getState()
            setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
              draft.resultTabActive = true
              draft.resultText += text
            }))
          },
          onTextReplace: (params) => {
            const { data: { text } } = params
            const {
              workflowRunningData,
              setWorkflowRunningData,
            } = workflowStore.getState()
            setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
              draft.resultText = text
            }))
          },
          onTTSChunk: (messageId: string, audio: string, audioType?: string) => {
            if (!audio || audio === '')
              return
            player.playAudioWithAudio(audio, true)
            AudioPlayerManager.getInstance().resetMsgId(messageId)
          },
          onTTSEnd: (messageId: string, audio: string, audioType?: string) => {
            player.playAudioWithAudio(audio, false)
          },
          ...restCallback,
        },
      )
      // `pathname` is read when choosing the TTS URL, so it must be a
      // dependency or the callback keeps a stale path after navigation.
    }, [store, reactflow, workflowStore, doSyncWorkflowDraft, pathname])

    /**
     * Request a stop for the running task `taskId` of the current app.
     * Does nothing when no app is loaded, rather than calling an
     * `/apps/undefined/...` URL.
     */
    const handleStopRun = useCallback((taskId: string) => {
      const appId = useAppStore.getState().appDetail?.id
      if (!appId)
        return
      stopWorkflowRun(`/apps/${appId}/workflow-runs/tasks/${taskId}/stop`)
    }, [])

    /**
     * Fetch the published workflow of the current app and load it into the
     * canvas, features store and workflow store (published time and
     * environment variables). Does nothing when no app is loaded or nothing is
     * returned.
     */
    const handleRestoreFromPublishedWorkflow = useCallback(async () => {
      const appId = useAppStore.getState().appDetail?.id
      if (!appId)
        return
      const publishedWorkflow = await fetchPublishedWorkflow(`/apps/${appId}/workflows/publish`)
      if (!publishedWorkflow)
        return
      const { nodes, edges, viewport } = publishedWorkflow.graph
      handleUpdateWorkflowCanvas({
        nodes,
        edges,
        viewport: viewport!,
      })
      featuresStore?.setState({ features: publishedWorkflow.features })
      workflowStore.getState().setPublishedAt(publishedWorkflow.created_at)
      workflowStore.getState().setEnvironmentVariables(publishedWorkflow.environment_variables || [])
    }, [featuresStore, handleUpdateWorkflowCanvas, workflowStore])

    return {
      handleBackupDraft,
      handleLoadBackupDraft,
      handleRun,
      handleStopRun,
      handleRestoreFromPublishedWorkflow,
    }
  }