// utils.ts (13 KB)
  1. import {
  2. Position,
  3. getConnectedEdges,
  4. getOutgoers,
  5. } from 'reactflow'
  6. import dagre from '@dagrejs/dagre'
  7. import { v4 as uuid4 } from 'uuid'
  8. import {
  9. cloneDeep,
  10. uniqBy,
  11. } from 'lodash-es'
  12. import type {
  13. Edge,
  14. InputVar,
  15. Node,
  16. ToolWithProvider,
  17. } from './types'
  18. import { BlockEnum } from './types'
  19. import {
  20. ITERATION_NODE_Z_INDEX,
  21. NODE_WIDTH_X_OFFSET,
  22. START_INITIAL_POSITION,
  23. } from './constants'
  24. import type { QuestionClassifierNodeType } from './nodes/question-classifier/types'
  25. import type { ToolNodeType } from './nodes/tool/types'
  26. import { CollectionType } from '@/app/components/tools/types'
  27. import { toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema'
  /**
   * Returns the edges that lie on at least one directed cycle.
   *
   * An edge `u -> v` lies on a cycle exactly when `u` and `v` belong to the same
   * strongly connected component (a self-loop `u -> u` included). Components are
   * therefore computed with Tarjan's algorithm, and every edge whose endpoints
   * share a component is reported.
   *
   * This differs from "mark the DFS stack when a GRAY node is hit": it never
   * reports edges on a path that merely leads into a cycle (e.g. Start -> loop).
   * It also does not miss loop edges that close through an already-finished node.
   *
   * Edges whose source or target is not among `nodes` are never reported.
   * Neither argument is mutated.
   *
   * @param nodes - graph nodes; only `id` is read
   * @param edges - graph edges; only `source` and `target` are read
   * @returns the cyclic subset of `edges`, in their original order
   */
  const getCycleEdges = (nodes: Node[], edges: Edge[]): Edge[] => {
    const adjacency = new Map<string, string[]>()
    for (const node of nodes)
      adjacency.set(node.id, [])
    for (const edge of edges)
      adjacency.get(edge.source)?.push(edge.target)

    // Tarjan bookkeeping: DFS discovery order, lowest reachable discovery index,
    // the component each node ends up in, and the visited-but-unassigned stack.
    const discoveryIndex = new Map<string, number>()
    const lowLink = new Map<string, number>()
    const componentOf = new Map<string, number>()
    const pending: string[] = []
    const pendingSet = new Set<string>()
    let nextIndex = 0
    let componentCount = 0

    const strongConnect = (nodeId: string) => {
      const index = nextIndex++
      let low = index
      discoveryIndex.set(nodeId, index)
      pending.push(nodeId)
      pendingSet.add(nodeId)

      for (const childId of adjacency.get(nodeId) ?? []) {
        // Targets that are not graph nodes cannot be part of a cycle.
        if (!adjacency.has(childId))
          continue
        const childIndex = discoveryIndex.get(childId)
        if (childIndex === undefined) {
          strongConnect(childId)
          low = Math.min(low, lowLink.get(childId) ?? low)
        }
        else if (pendingSet.has(childId)) {
          // Edge back to a node still on the pending stack: same component.
          low = Math.min(low, childIndex)
        }
      }
      lowLink.set(nodeId, low)

      // `nodeId` roots a component: pop every member of it off the stack.
      if (low === index) {
        let memberId: string | undefined
        do {
          memberId = pending.pop()
          if (memberId === undefined)
            break
          pendingSet.delete(memberId)
          componentOf.set(memberId, componentCount)
        } while (memberId !== nodeId)
        componentCount++
      }
    }

    for (const node of nodes) {
      if (!discoveryIndex.has(node.id))
        strongConnect(node.id)
    }

    return edges.filter((edge) => {
      const sourceComponent = componentOf.get(edge.source)
      return sourceComponent !== undefined && sourceComponent === componentOf.get(edge.target)
    })
  }
  /**
   * Prepares persisted workflow nodes for rendering.
   *
   * Works on deep clones of `originNodes` / `originEdges`, so the inputs are not
   * modified. Every returned node:
   * - has `type` forced to `'custom'`;
   * - records the handle ids of its connected outgoing / incoming edges in
   *   `data._connectedSourceHandleIds` / `data._connectedTargetHandleIds`
   *   (an edge without a handle id counts as `'source'` / `'target'`);
   * - gets `data._targetBranches` when it is an IfElse node (fixed true/false
   *   branches) or a QuestionClassifier node (a copy of its `classes`);
   * - gets `data._children` when it is an Iteration node: the ids of the nodes
   *   whose `parentId` is this node (empty when there are none).
   * When the first node has no `position`, all nodes are laid out on one row
   * starting at `START_INITIAL_POSITION`, spaced `NODE_WIDTH_X_OFFSET` apart.
   */
  export const initialNodes = (originNodes: Node[], originEdges: Edge[]) => {
    const nodes = cloneDeep(originNodes)
    const edges = cloneDeep(originEdges)

    if (!nodes[0]?.position) {
      let column = 0
      for (const node of nodes) {
        node.position = {
          x: START_INITIAL_POSITION.x + column * NODE_WIDTH_X_OFFSET,
          y: START_INITIAL_POSITION.y,
        }
        column++
      }
    }

    // parent node id -> ids of the nodes nested inside it
    const childIdsByParentId = {} as Record<string, string[]>
    for (const node of nodes) {
      if (!node.parentId)
        continue
      const siblings = childIdsByParentId[node.parentId]
      if (siblings)
        siblings.push(node.id)
      else
        childIdsByParentId[node.parentId] = [node.id]
    }

    return nodes.map((node) => {
      node.type = 'custom'

      const connectedEdges = getConnectedEdges([node], edges)
      const outgoingEdges = connectedEdges.filter(edge => edge.source === node.id)
      const incomingEdges = connectedEdges.filter(edge => edge.target === node.id)
      node.data._connectedSourceHandleIds = outgoingEdges.map(edge => edge.sourceHandle || 'source')
      node.data._connectedTargetHandleIds = incomingEdges.map(edge => edge.targetHandle || 'target')

      switch (node.data.type) {
        case BlockEnum.IfElse:
          node.data._targetBranches = [
            { id: 'true', name: 'IS TRUE' },
            { id: 'false', name: 'IS FALSE' },
          ]
          break
        case BlockEnum.QuestionClassifier:
          node.data._targetBranches = (node.data as QuestionClassifierNodeType).classes.slice()
          break
        case BlockEnum.Iteration:
          node.data._children = childIdsByParentId[node.id] || []
          break
      }

      return node
    })
  }
  /**
   * Prepares persisted workflow edges for rendering.
   *
   * Works on deep clones of the inputs. Edges matching (by `source`/`target`
   * pair) any edge returned by `getCycleEdges` are dropped. Every remaining edge:
   * - has `type` forced to `'custom'`;
   * - gets `'source'` / `'target'` handles when its handle ids are missing;
   * - gets `data.sourceType` / `data.targetType` copied from the endpoint node's
   *   `data.type` when not already set and the endpoint node exists;
   * - when some node has `data.selected` (the last such node wins), gets
   *   `data._connectedNodeIsSelected` telling whether the edge touches that node.
   */
  export const initialEdges = (originEdges: Edge[], originNodes: Node[]) => {
    const nodes = cloneDeep(originNodes)
    const edges = cloneDeep(originEdges)

    const nodesMap = {} as Record<string, Node>
    let lastSelected: Node | null = null
    for (const node of nodes) {
      nodesMap[node.id] = node
      if (node.data?.selected)
        lastSelected = node
    }
    const selectedNode = lastSelected

    const cycleEdges = getCycleEdges(nodes, edges)
    const isCycleEdge = (edge: Edge) => cycleEdges.some(
      cycleEdge => cycleEdge.source === edge.source && cycleEdge.target === edge.target,
    )

    return edges
      .filter(edge => !isCycleEdge(edge))
      .map((edge) => {
        edge.type = 'custom'
        edge.sourceHandle = edge.sourceHandle || 'source'
        edge.targetHandle = edge.targetHandle || 'target'

        const sourceNode = edge.source ? nodesMap[edge.source] : undefined
        if (!edge.data?.sourceType && sourceNode) {
          edge.data = {
            ...edge.data,
            sourceType: sourceNode.data.type!,
          } as any
        }

        const targetNode = edge.target ? nodesMap[edge.target] : undefined
        if (!edge.data?.targetType && targetNode) {
          edge.data = {
            ...edge.data,
            targetType: targetNode.data.type!,
          } as any
        }

        if (selectedNode) {
          edge.data = {
            ...edge.data,
            _connectedNodeIsSelected: edge.source === selectedNode.id || edge.target === selectedNode.id,
          } as any
        }

        return edge
      })
  }
  /**
   * Runs a left-to-right dagre layout over the top-level part of the workflow.
   *
   * Nodes that have a `parentId` and edges flagged `data.isInIteration` are left
   * out. Each node is sized from its `width` / `height`, which are read with `!`
   * (presumably already measured by the renderer — confirm at call sites).
   * The inputs are cloned, not modified.
   *
   * @returns the dagre graph after `dagre.layout` has run on it
   */
  export const getLayoutByDagre = (originNodes: Node[], originEdges: Edge[]) => {
    const graph = new dagre.graphlib.Graph()
    graph.setDefaultEdgeLabel(() => ({}))

    const topLevelNodes = cloneDeep(originNodes).filter(node => !node.parentId)
    const topLevelEdges = cloneDeep(originEdges).filter(edge => !edge.data?.isInIteration)

    graph.setGraph({
      rankdir: 'LR',
      align: 'UL',
      nodesep: 40,
      ranksep: 60,
      ranker: 'tight-tree',
      marginx: 30,
      marginy: 200,
    })

    for (const { id, width, height } of topLevelNodes)
      graph.setNode(id, { width: width!, height: height! })

    for (const { source, target } of topLevelEdges)
      graph.setEdge(source, target)

    dagre.layout(graph)
    return graph
  }
  /**
   * Whether a node of `nodeType` may be run on its own (presumably the
   * single-step run feature — confirm with callers).
   */
  export const canRunBySingle = (nodeType: BlockEnum) => {
    const singleRunnableTypes: BlockEnum[] = [
      BlockEnum.LLM,
      BlockEnum.KnowledgeRetrieval,
      BlockEnum.Code,
      BlockEnum.TemplateTransform,
      BlockEnum.QuestionClassifier,
      BlockEnum.HttpRequest,
      BlockEnum.Tool,
      BlockEnum.ParameterExtractor,
      BlockEnum.Iteration,
    ]
    return singleRunnableTypes.includes(nodeType)
  }
  type ConnectedSourceOrTargetNodesChange = {
    type: string
    edge: Edge
  }[]

  type ConnectedHandleIds = {
    _connectedSourceHandleIds: string[]
    _connectedTargetHandleIds: string[]
  }

  /**
   * Computes updated `_connectedSourceHandleIds` / `_connectedTargetHandleIds`
   * lists for the nodes touched by a batch of edge changes.
   *
   * For every change, both endpoint nodes are looked up in `nodes` by
   * `edge.source` / `edge.target`; unknown ids are skipped. Each found node gets
   * an entry seeded with copies of its current lists, so the nodes themselves
   * are never mutated. Then:
   * - `type === 'add'` appends the edge's handle id;
   * - `type === 'remove'` removes one occurrence of that handle id, if present;
   * - any other `type` only seeds the entries.
   * A missing handle id is normalised to `'source'` / `'target'`, the same
   * defaults used when the lists are built, so add and remove agree.
   *
   * @returns a map from node id to that node's recomputed handle-id lists
   */
  export const getNodesConnectedSourceOrTargetHandleIdsMap = (changes: ConnectedSourceOrTargetNodesChange, nodes: Node[]) => {
    const nodesConnectedSourceOrTargetHandleIdsMap: Record<string, ConnectedHandleIds> = {}

    // Returns the node's entry, seeding it from the node's current data on first use.
    const ensureEntry = (node: Node): ConnectedHandleIds => {
      const existing = nodesConnectedSourceOrTargetHandleIdsMap[node.id]
      if (existing)
        return existing
      const created: ConnectedHandleIds = {
        _connectedSourceHandleIds: [...(node.data._connectedSourceHandleIds || [])],
        _connectedTargetHandleIds: [...(node.data._connectedTargetHandleIds || [])],
      }
      nodesConnectedSourceOrTargetHandleIdsMap[node.id] = created
      return created
    }

    // Removes the first occurrence of `handleId`; a no-op when it is absent.
    // (A bare `splice(findIndex(...), 1)` would drop the LAST handle on a miss.)
    const removeFirst = (handleIds: string[], handleId: string) => {
      const index = handleIds.indexOf(handleId)
      if (index !== -1)
        handleIds.splice(index, 1)
    }

    for (const { edge, type } of changes) {
      const sourceNode = nodes.find(node => node.id === edge.source)
      const targetNode = nodes.find(node => node.id === edge.target)
      // Seed both entries before applying the change (also for unknown change types).
      const sourceEntry = sourceNode ? ensureEntry(sourceNode) : undefined
      const targetEntry = targetNode ? ensureEntry(targetNode) : undefined
      const sourceHandleId = edge.sourceHandle || 'source'
      const targetHandleId = edge.targetHandle || 'target'

      if (sourceEntry) {
        if (type === 'remove')
          removeFirst(sourceEntry._connectedSourceHandleIds, sourceHandleId)
        if (type === 'add')
          sourceEntry._connectedSourceHandleIds.push(sourceHandleId)
      }
      if (targetEntry) {
        if (type === 'remove')
          removeFirst(targetEntry._connectedTargetHandleIds, targetHandleId)
        if (type === 'add')
          targetEntry._connectedTargetHandleIds.push(targetHandleId)
      }
    }

    return nodesConnectedSourceOrTargetHandleIdsMap
  }
  /**
   * Builds a node of type `'custom'` with its target handle on the left and its
   * source handle on the right.
   *
   * `id` defaults to the current timestamp (`Date.now()` as a string) when it is
   * missing or empty. Iteration nodes always get `ITERATION_NODE_Z_INDEX`; other
   * nodes keep the given `zIndex`. Remaining properties in `rest` are spread
   * last, so they override the defaults above (including `type` and the handle
   * positions).
   */
  export const generateNewNode = ({ data, position, id, zIndex, ...rest }: Omit<Node, 'id'> & { id?: string }) => {
    const nodeId = id || `${Date.now()}`
    const nodeZIndex = data.type === BlockEnum.Iteration ? ITERATION_NODE_Z_INDEX : zIndex
    const node = {
      id: nodeId,
      type: 'custom',
      data,
      position,
      targetPosition: Position.Left,
      sourcePosition: Position.Right,
      zIndex: nodeZIndex,
      ...rest,
    }
    return node as Node
  }
  // A title that ends in a parenthesised counter, e.g. "LLM (2)" or "Code(3) ".
  const NUMBERED_TITLE_REGEX = /^(.+?)\s*\((\d+)\)\s*$/

  /**
   * Derives the title for a copy of a node.
   *
   * A title that already ends in "(N)" gets its counter bumped, normalised to
   * "Title (N+1)"; any other title gets " (1)" appended.
   */
  export const genNewNodeTitleFromOld = (oldTitle: string) => {
    const parts = oldTitle.match(NUMBERED_TITLE_REGEX)
    if (!parts)
      return `${oldTitle} (1)`

    const [, baseTitle, counter] = parts
    const nextCounter = Number.parseInt(counter, 10) + 1
    return `${baseTitle} (${nextCounter})`
  }
  /**
   * Collects the nodes reachable from the Start node, and the number of nodes
   * on the longest path from it.
   *
   * Reachability follows outgoing edges (`getOutgoers`). When an Iteration node
   * is reached, the nodes whose `parentId` is that node are included too.
   * `maxDepth` counts nodes along the deepest path found (Start alone is 1).
   * Without a Start node the result is `{ validNodes: [], maxDepth: 0 }`.
   *
   * A node is re-expanded only when it is reached at a greater depth than
   * before. Re-converging branches therefore no longer cost one full traversal
   * per path (exponential in the number of branch/merge points). An edge back
   * into the current path is not followed, so a cycle no longer recurses until
   * the call stack overflows.
   *
   * @returns `validNodes` de-duplicated by id (first-visit order) and `maxDepth`
   */
  export const getValidTreeNodes = (nodes: Node[], edges: Edge[]) => {
    const startNode = nodes.find(node => node.data.type === BlockEnum.Start)
    if (!startNode) {
      return {
        validNodes: [],
        maxDepth: 0,
      }
    }

    const list: Node[] = [startNode]
    let maxDepth = 1
    // Deepest depth at which each node has been expanded so far.
    const expandedAtDepth = new Map<string, number>()
    // Nodes on the current DFS path; an edge back to one of them closes a cycle.
    const onPath = new Set<string>()

    const traverse = (root: Node, depth: number) => {
      if (depth > maxDepth)
        maxDepth = depth
      // Everything below `root` was already collected, from at least this deep.
      const previousDepth = expandedAtDepth.get(root.id)
      if (previousDepth !== undefined && previousDepth >= depth)
        return
      expandedAtDepth.set(root.id, depth)
      onPath.add(root.id)

      getOutgoers(root, nodes, edges).forEach((outgoer) => {
        list.push(outgoer)
        if (outgoer.data.type === BlockEnum.Iteration)
          list.push(...nodes.filter(node => node.parentId === outgoer.id))
        if (!onPath.has(outgoer.id))
          traverse(outgoer, depth + 1)
      })

      onPath.delete(root.id)
    }

    traverse(startNode, maxDepth)
    return {
      validNodes: uniqBy(list, 'id'),
      maxDepth,
    }
  }
  /**
   * Gathers what is needed to validate a tool node's configuration.
   *
   * The tool's collection is looked up by `provider_id` in the list matching
   * `provider_type`: built-in, custom, or otherwise the workflow tools. The tool
   * itself is then found by `tool_name`. Its parameters are converted with
   * `toolParametersToFormSchemas` and split by `form`:
   * - `'llm'` parameters become `toolInputsSchema` entries, with the label
   *   localised to `language` and falling back to `en_US`;
   * - all other parameters are returned as `toolSettingSchema`.
   * `notAuthed` is true only for a built-in collection that has `allow_delete`
   * set and `is_team_authorization` unset. An unknown collection or tool yields
   * empty schemas.
   */
  export const getToolCheckParams = (
    toolData: ToolNodeType,
    buildInTools: ToolWithProvider[],
    customTools: ToolWithProvider[],
    workflowTools: ToolWithProvider[],
    language: string,
  ) => {
    const { provider_id, provider_type, tool_name } = toolData
    const isBuiltIn = provider_type === CollectionType.builtIn

    let candidateCollections: ToolWithProvider[]
    if (isBuiltIn)
      candidateCollections = buildInTools
    else if (provider_type === CollectionType.custom)
      candidateCollections = customTools
    else
      candidateCollections = workflowTools

    const collection = candidateCollections.find(item => item.id === provider_id)
    const tool = collection?.tools.find(item => item.name === tool_name)
    const formSchemas = tool ? toolParametersToFormSchemas(tool.parameters) : []

    const llmInputSchemas = formSchemas.filter((item: any) => item.form === 'llm')
    const toolSettingSchema = formSchemas.filter((item: any) => item.form !== 'llm')

    const toolInputsSchema: InputVar[] = llmInputSchemas.map((item: any) => ({
      label: item.label[language] || item.label.en_US,
      variable: item.variable,
      type: item.type,
      required: item.required,
    }))

    return {
      toolInputsSchema,
      notAuthed: isBuiltIn && !!collection?.allow_delete && !collection?.is_team_authorization,
      toolSettingSchema,
      language,
    }
  }
  /**
   * Gives every node a fresh UUID and rewrites edge endpoints to match.
   *
   * Returns shallow copies; the inputs are not mutated. Edge `source`/`target`
   * values are translated through the old-id -> new-id table, so an endpoint
   * whose node is not in `nodes` becomes `undefined`.
   * NOTE(review): edge `id`s and node `parentId`s are copied unchanged — confirm
   * callers do not expect those to be remapped as well.
   *
   * @returns a `[nodes, edges]` tuple with the re-identified copies
   */
  export const changeNodesAndEdgesId = (nodes: Node[], edges: Edge[]) => {
    const idMap: Record<string, string> = {}
    for (const node of nodes)
      idMap[node.id] = uuid4()

    const newNodes: Node[] = nodes.map(node => ({ ...node, id: idMap[node.id] }))
    const newEdges: Edge[] = edges.map(edge => ({
      ...edge,
      source: idMap[edge.source],
      target: idMap[edge.target],
    }))

    return [newNodes, newEdges] as [Node[], Edge[]]
  }
  /**
   * Whether `navigator.userAgent`, upper-cased, contains "MAC".
   *
   * Returns false when there is no `navigator` global (e.g. server-side
   * rendering or plain Node), instead of throwing a ReferenceError.
   */
  export const isMac = () => {
    if (typeof navigator === 'undefined')
      return false
    return navigator.userAgent.toUpperCase().includes('MAC')
  }

  // Display names used on Mac for some modifier keys.
  const specialKeysNameMap: Record<string, string | undefined> = {
    ctrl: '⌘',
    alt: '⌥',
  }

  // Key codes used on Mac for some modifier keys.
  const specialKeysCodeMap: Record<string, string | undefined> = {
    ctrl: 'meta',
  }

  // On Mac, translates `key` through `keyMap`; keys missing from the map, and
  // every key on other systems, pass through unchanged. Only the map's own
  // properties count, so inherited names like "constructor" are not resolved
  // to Object.prototype members.
  const translateKeyForMac = (keyMap: Record<string, string | undefined>, key: string) => {
    if (!isMac())
      return key
    const mapped = Object.prototype.hasOwnProperty.call(keyMap, key) ? keyMap[key] : undefined
    return mapped || key
  }

  /** Display name of `key` for the current system (e.g. "ctrl" shows as "⌘" on Mac). */
  export const getKeyboardKeyNameBySystem = (key: string) => {
    return translateKeyForMac(specialKeysNameMap, key)
  }

  /** Key code of `key` for the current system (e.g. "ctrl" becomes "meta" on Mac). */
  export const getKeyboardKeyCodeBySystem = (key: string) => {
    return translateKeyForMac(specialKeysCodeMap, key)
  }
  /**
   * Returns the smallest `position.x` and the smallest `position.y` over `nodes`
   * (taken independently, so they may come from different nodes).
   *
   * An empty list yields `{ x: Infinity, y: Infinity }`. NaN coordinates never
   * win a `<` comparison and are therefore ignored.
   */
  export const getTopLeftNodePosition = (nodes: Node[]) => {
    return nodes.reduce(
      (corner, { position }) => ({
        x: position.x < corner.x ? position.x : corner.x,
        y: position.y < corner.y ? position.y : corner.y,
      }),
      { x: Infinity, y: Infinity },
    )
  }
  /**
   * Whether an event target is an element the user types text into: an
   * `<input>`, a `<textarea>`, or an editable (contenteditable) element.
   *
   * Always returns a boolean; a non-matching target yields `false` rather than
   * falling off the end with `undefined`. A missing target also yields `false`.
   */
  export const isEventTargetInputArea = (target: HTMLElement) => {
    if (!target)
      return false
    if (target.tagName === 'INPUT' || target.tagName === 'TEXTAREA')
      return true
    // `isContentEditable` also covers contenteditable="plaintext-only" and
    // elements inside an editable host, which the attribute check alone misses.
    return target.contentEditable === 'true' || !!target.isContentEditable
  }