wwf
3 天以前 a430284aa21e3ae1f0d5654e55b2ad2852519cc2
app/components/workflow/nodes/_base/hooks/use-one-step-run.ts
@@ -1,4 +1,4 @@
import { useCallback, useEffect, useRef, useState } from 'react'
import { useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { unionBy } from 'lodash-es'
import produce from 'immer'
@@ -13,7 +13,7 @@
import { BlockEnum, InputVarType, NodeRunningStatus, VarType } from '@/app/components/workflow/types'
import { useStore as useAppStore } from '@/app/components/app/store'
import { useStore, useWorkflowStore } from '@/app/components/workflow/store'
import { getIterationSingleNodeRunUrl, getLoopSingleNodeRunUrl, singleNodeRun } from '@/service/workflow'
import { getIterationSingleNodeRunUrl, singleNodeRun } from '@/service/workflow'
import Toast from '@/app/components/base/toast'
import LLMDefault from '@/app/components/workflow/nodes/llm/default'
import KnowledgeRetrievalDefault from '@/app/components/workflow/nodes/knowledge-retrieval/default'
@@ -28,9 +28,8 @@
import ParameterExtractorDefault from '@/app/components/workflow/nodes/parameter-extractor/default'
import IterationDefault from '@/app/components/workflow/nodes/iteration/default'
import DocumentExtractorDefault from '@/app/components/workflow/nodes/document-extractor/default'
import LoopDefault from '@/app/components/workflow/nodes/loop/default'
import { ssePost } from '@/service/base'
import { noop } from 'lodash-es'
import { getInputVars as doGetInputVars } from '@/app/components/base/prompt-editor/constants'
import type { NodeTracing } from '@/types/workflow'
const { checkValid: checkLLMValid } = LLMDefault
@@ -46,9 +45,7 @@
const { checkValid: checkParameterExtractorValid } = ParameterExtractorDefault
const { checkValid: checkIterationValid } = IterationDefault
const { checkValid: checkDocumentExtractorValid } = DocumentExtractorDefault
const { checkValid: checkLoopValid } = LoopDefault
// eslint-disable-next-line ts/no-unsafe-function-type
const checkValidFns: Record<BlockEnum, Function> = {
  [BlockEnum.LLM]: checkLLMValid,
  [BlockEnum.KnowledgeRetrieval]: checkKnowledgeRetrievalValid,
@@ -63,7 +60,6 @@
  [BlockEnum.ParameterExtractor]: checkParameterExtractorValid,
  [BlockEnum.Iteration]: checkIterationValid,
  [BlockEnum.DocExtractor]: checkDocumentExtractorValid,
  [BlockEnum.Loop]: checkLoopValid,
} as any
type Params<T> = {
@@ -72,7 +68,6 @@
  defaultRunInputData: Record<string, any>
  moreDataForCheckValid?: any
  iteratorInputKey?: string
  loopInputKey?: string
}
const varTypeToInputVarType = (type: VarType, {
@@ -104,14 +99,12 @@
  defaultRunInputData,
  moreDataForCheckValid,
  iteratorInputKey,
  loopInputKey,
}: Params<T>) => {
  const { t } = useTranslation()
  const { getBeforeNodesInSameBranch, getBeforeNodesInSameBranchIncludeParent } = useWorkflow() as any
  const conversationVariables = useStore(s => s.conversationVariables)
  const isChatMode = useIsChatMode()
  const isIteration = data.type === BlockEnum.Iteration
  const isLoop = data.type === BlockEnum.Loop
  const availableNodes = getBeforeNodesInSameBranch(id)
  const availableNodesIncludeParent = getBeforeNodesInSameBranchIncludeParent(id)
@@ -145,20 +138,13 @@
  const checkValid = checkValidFns[data.type]
  const appId = useAppStore.getState().appDetail?.id
  const [runInputData, setRunInputData] = useState<Record<string, any>>(defaultRunInputData || {})
  // Ref initialised with the current runInputData; handleSetRunInputData below
  // keeps it in step with the state value. The ref is also returned from the hook.
  const runInputDataRef = useRef(runInputData)
  // Setter exposed to callers as `setRunInputData`: writes the new inputs into the
  // ref first, then into React state. The empty dependency list keeps the callback
  // identity stable across renders.
  // NOTE(review): presumably the ref exists so consumers can read the latest inputs
  // before the next render commits — confirm against callers.
  const handleSetRunInputData = useCallback((data: Record<string, any>) => {
    runInputDataRef.current = data
    setRunInputData(data)
  }, [])
  const iterationTimes = iteratorInputKey ? runInputData[iteratorInputKey].length : 0
  const loopTimes = loopInputKey ? runInputData[loopInputKey].length : 0
  const [runResult, setRunResult] = useState<any>(null)
  const { handleNodeDataUpdate }: { handleNodeDataUpdate: (data: any) => void } = useNodeDataUpdate()
  const [canShowSingleRun, setCanShowSingleRun] = useState(false)
  const isShowSingleRun = data._isSingleRun && canShowSingleRun
  const [iterationRunResult, setIterationRunResult] = useState<NodeTracing[]>([])
  const [loopRunResult, setLoopRunResult] = useState<NodeTracing[]>([])
  const [iterationRunResult, setIterationRunResult] = useState<NodeTracing[][]>([])
  useEffect(() => {
    if (!checkValid) {
@@ -183,13 +169,13 @@
        })
      }
    }
    // eslint-disable-next-line react-hooks/exhaustive-deps
  // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [data._isSingleRun])
  const workflowStore = useWorkflowStore()
  useEffect(() => {
    workflowStore.getState().setShowSingleRunPanel(!!isShowSingleRun)
  }, [isShowSingleRun, workflowStore])
  }, [isShowSingleRun])
  const hideSingleRun = () => {
    handleNodeDataUpdate({
@@ -222,18 +208,19 @@
    })
    let res: any
    try {
      if (!isIteration && !isLoop) {
      if (!isIteration) {
        res = await singleNodeRun(appId!, id, { inputs: submitData }) as any
      }
      else if (isIteration) {
      else {
        setIterationRunResult([])
        let _iterationResult: NodeTracing[] = []
        let _iterationResult: NodeTracing[][] = []
        let _runResult: any = null
        ssePost(
          getIterationSingleNodeRunUrl(isChatMode, appId!, id),
          { body: { inputs: submitData } },
          {
            onWorkflowStarted: noop,
            onWorkflowStarted: () => {
            },
            onWorkflowFinished: (params) => {
              handleNodeDataUpdate({
                id,
@@ -246,43 +233,27 @@
              _runResult.created_by = iterationData.created_by.name
              setRunResult(_runResult)
            },
            onIterationStart: (params) => {
              const newIterationRunResult = produce(_iterationResult, (draft) => {
                draft.push({
                  ...params.data,
                  status: NodeRunningStatus.Running,
                })
              })
              _iterationResult = newIterationRunResult
              setIterationRunResult(newIterationRunResult)
            },
            onIterationNext: () => {
              // iteration next trigger time is triggered one more time than iterationTimes
              if (_iterationResult.length >= iterationTimes!)
                return _iterationResult.length >= iterationTimes!
                return
              const newIterationRunResult = produce(_iterationResult, (draft) => {
                draft.push([])
              })
              _iterationResult = newIterationRunResult
              setIterationRunResult(newIterationRunResult)
            },
            onIterationFinish: (params) => {
              _runResult = params.data
              setRunResult(_runResult)
              const iterationRunResult = _iterationResult
              const currentIndex = iterationRunResult.findIndex(trace => trace.id === params.data.id)
              const newIterationRunResult = produce(iterationRunResult, (draft) => {
                if (currentIndex > -1) {
                  draft[currentIndex] = {
                    ...draft[currentIndex],
                    ...data,
                  }
                }
              })
              _iterationResult = newIterationRunResult
              setIterationRunResult(newIterationRunResult)
            },
            onNodeStarted: (params) => {
              const newIterationRunResult = produce(_iterationResult, (draft) => {
                draft.push({
                draft[draft.length - 1].push({
                  ...params.data,
                  status: NodeRunningStatus.Running,
                })
                } as NodeTracing)
              })
              _iterationResult = newIterationRunResult
              setIterationRunResult(newIterationRunResult)
@@ -291,21 +262,14 @@
              const iterationRunResult = _iterationResult
              const { data } = params
              const currentIndex = iterationRunResult.findIndex(trace => trace.id === data.id)
              const currentIndex = iterationRunResult[iterationRunResult.length - 1].findIndex(trace => trace.node_id === data.node_id)
              const newIterationRunResult = produce(iterationRunResult, (draft) => {
                if (currentIndex > -1) {
                  draft[currentIndex] = {
                    ...draft[currentIndex],
                  draft[draft.length - 1][currentIndex] = {
                    ...data,
                  }
                    status: NodeRunningStatus.Succeeded,
                  } as NodeTracing
                }
              })
              _iterationResult = newIterationRunResult
              setIterationRunResult(newIterationRunResult)
            },
            onNodeRetry: (params) => {
              const newIterationRunResult = produce(_iterationResult, (draft) => {
                draft.push(params.data)
              })
              _iterationResult = newIterationRunResult
              setIterationRunResult(newIterationRunResult)
@@ -322,110 +286,11 @@
          },
        )
      }
      else if (isLoop) {
        // Single-run path for Loop nodes: clear the previous loop traces, then start
        // the run through ssePost and collect event payloads as they arrive.
        // `_loopResult` is the local running copy of the trace list; every handler
        // that changes it also mirrors the new value into React state via
        // setLoopRunResult. `_runResult` holds the loop's final payload once
        // onLoopFinish has fired.
        // Note: ssePost is called without `await`, and `res` is not assigned in
        // this branch.
        setLoopRunResult([])
        let _loopResult: NodeTracing[] = []
        let _runResult: any = null
        ssePost(
          getLoopSingleNodeRunUrl(isChatMode, appId!, id),
          { body: { inputs: submitData } },
          {
            onWorkflowStarted: noop,
            // Marks the node as Succeeded, then stamps the creator name onto the
            // stored run result and publishes it.
            // NOTE(review): `_runResult` is still null unless onLoopFinish ran
            // first, in which case the `created_by` assignment below would throw —
            // confirm the event ordering guaranteed by the backend.
            onWorkflowFinished: (params) => {
              handleNodeDataUpdate({
                id,
                data: {
                  ...data,
                  _singleRunningStatus: NodeRunningStatus.Succeeded,
                },
              })
              const { data: loopData } = params
              _runResult.created_by = loopData.created_by.name
              setRunResult(_runResult)
            },
            // Appends the loop's own trace entry, flagged as Running.
            onLoopStart: (params) => {
              const newLoopRunResult = produce(_loopResult, (draft) => {
                draft.push({
                  ...params.data,
                  status: NodeRunningStatus.Running,
                })
              })
              _loopResult = newLoopRunResult
              setLoopRunResult(newLoopRunResult)
            },
            onLoopNext: () => {
              // The loop-next event fires one more time than loopTimes.
              // NOTE(review): this handler changes no state; it only returns a
              // boolean once the collected trace count reaches loopTimes — confirm
              // whether ssePost uses the return value.
              if (_loopResult.length >= loopTimes!)
                return _loopResult.length >= loopTimes!
            },
            // Stores the loop's final payload as the run result, then merges into
            // the trace entry whose id matches params.data.id.
            // NOTE(review): no local `data` is destructured here (unlike
            // onNodeFinished), so `...data` below spreads the outer `data` that is
            // also used in handleNodeDataUpdate, not `params.data` — confirm
            // whether `...params.data` was intended.
            onLoopFinish: (params) => {
              _runResult = params.data
              setRunResult(_runResult)
              const loopRunResult = _loopResult
              const currentIndex = loopRunResult.findIndex(trace => trace.id === params.data.id)
              const newLoopRunResult = produce(loopRunResult, (draft) => {
                if (currentIndex > -1) {
                  draft[currentIndex] = {
                    ...draft[currentIndex],
                    ...data,
                  }
                }
              })
              _loopResult = newLoopRunResult
              setLoopRunResult(newLoopRunResult)
            },
            // Appends a trace entry for a node that has just started, flagged as Running.
            onNodeStarted: (params) => {
              const newLoopRunResult = produce(_loopResult, (draft) => {
                draft.push({
                  ...params.data,
                  status: NodeRunningStatus.Running,
                })
              })
              _loopResult = newLoopRunResult
              setLoopRunResult(newLoopRunResult)
            },
            // Merges the finished node's payload into the existing trace entry with
            // the same id; does nothing if no such entry is found.
            onNodeFinished: (params) => {
              const loopRunResult = _loopResult
              const { data } = params
              const currentIndex = loopRunResult.findIndex(trace => trace.id === data.id)
              const newLoopRunResult = produce(loopRunResult, (draft) => {
                if (currentIndex > -1) {
                  draft[currentIndex] = {
                    ...draft[currentIndex],
                    ...data,
                  }
                }
              })
              _loopResult = newLoopRunResult
              setLoopRunResult(newLoopRunResult)
            },
            // Appends the retry payload as a new trace entry, unchanged.
            onNodeRetry: (params) => {
              const newLoopRunResult = produce(_loopResult, (draft) => {
                draft.push(params.data)
              })
              _loopResult = newLoopRunResult
              setLoopRunResult(newLoopRunResult)
            },
            // Marks the node's single-run status as Failed.
            onError: () => {
              handleNodeDataUpdate({
                id,
                data: {
                  ...data,
                  _singleRunningStatus: NodeRunningStatus.Failed,
                },
              })
            },
          },
        )
      }
      if (res && res.error)
      if (res.error)
        throw new Error(res.error)
    }
    catch (e: any) {
      console.error(e)
      if (!isIteration && !isLoop) {
      if (!isIteration) {
        handleNodeDataUpdate({
          id,
          data: {
@@ -437,7 +302,7 @@
      }
    }
    finally {
      if (!isIteration && !isLoop) {
      if (!isIteration) {
        setRunResult({
          ...res,
          total_tokens: res.execution_metadata?.total_tokens || 0,
@@ -445,7 +310,7 @@
        })
      }
    }
    if (!isIteration && !isLoop) {
    if (!isIteration) {
      handleNodeDataUpdate({
        id,
        data: {
@@ -532,11 +397,9 @@
    handleRun,
    handleStop,
    runInputData,
    runInputDataRef,
    setRunInputData: handleSetRunInputData,
    setRunInputData,
    runResult,
    iterationRunResult,
    loopRunResult,
  }
}