default.ts 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. import type { NodeDefault } from '../../types'
  2. import { BlockEnum } from '../../types'
  3. import type { QuestionClassifierNodeType } from './types'
  4. import { ALL_CHAT_AVAILABLE_BLOCKS, ALL_COMPLETION_AVAILABLE_BLOCKS } from '@/app/components/workflow/blocks'
  5. const i18nPrefix = 'workflow'
  6. const nodeDefault: NodeDefault<QuestionClassifierNodeType> = {
  7. defaultValue: {
  8. query_variable_selector: [],
  9. model: {
  10. provider: '',
  11. name: '',
  12. mode: 'chat',
  13. completion_params: {
  14. temperature: 0.7,
  15. },
  16. },
  17. classes: [
  18. {
  19. id: '1',
  20. name: '',
  21. },
  22. {
  23. id: '2',
  24. name: '',
  25. },
  26. ],
  27. _targetBranches: [
  28. {
  29. id: '1',
  30. name: '',
  31. },
  32. {
  33. id: '2',
  34. name: '',
  35. },
  36. ],
  37. vision: {
  38. enabled: false,
  39. },
  40. },
  41. getAvailablePrevNodes(isChatMode: boolean) {
  42. const nodes = isChatMode
  43. ? ALL_CHAT_AVAILABLE_BLOCKS
  44. : ALL_COMPLETION_AVAILABLE_BLOCKS.filter(type => type !== BlockEnum.End)
  45. return nodes
  46. },
  47. getAvailableNextNodes(isChatMode: boolean) {
  48. const nodes = isChatMode ? ALL_CHAT_AVAILABLE_BLOCKS : ALL_COMPLETION_AVAILABLE_BLOCKS
  49. return nodes
  50. },
  51. checkValid(payload: QuestionClassifierNodeType, t: any) {
  52. let errorMessages = ''
  53. if (!errorMessages && (!payload.query_variable_selector || payload.query_variable_selector.length === 0))
  54. errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.nodes.questionClassifiers.inputVars`) })
  55. if (!errorMessages && !payload.model.provider)
  56. errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.nodes.questionClassifiers.model`) })
  57. if (!errorMessages && (!payload.classes || payload.classes.length === 0))
  58. errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.nodes.questionClassifiers.class`) })
  59. if (!errorMessages && (payload.classes.some(item => !item.name)))
  60. errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.nodes.questionClassifiers.topicName`) })
  61. if (!errorMessages && payload.vision?.enabled && !payload.vision.configs?.variable_selector?.length)
  62. errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.errorMsg.fields.visionVariable`) })
  63. return {
  64. isValid: !errorMessages,
  65. errorMessage: errorMessages,
  66. }
  67. },
  68. }
  69. export default nodeDefault