import React, { useCallback } from "react"
import { aichemyProtoAxios } from 'API/mmAxios'
import { useAppStoreKey } from 'AppStore'
import {
  ioIsDefined,
  poll,
  checkTaskStatus,
  getTaskResults,
  outputContainsString,
  outputIsCatEncode,
  combineModelLists
} from "../utils";

/**
 * Backend task status -> label stored on jobs in the training state and shown
 * to the user.
 */
const TaskStatusMap = {
  // Terminal statuses: the backend prefixes these with "complete#".
  'complete#success': 'Finished',
  'complete#error': 'Error',
  'complete#killed': 'Killed',
  // Statuses of tasks that are still pending.
  running: 'Running',
  queued: 'Queued',
}

// Request options shared by every call to the task API.
const JSON_REQUEST_CONFIG = {
  headers: { "Content-Type": "application/json; charset=utf-8" },
}

// How long validation warnings stay on screen.
const AUTO_HIDE_DURATION = 10000 // 10 seconds

// Delay between two task status checks while polling.
const POLL_INTERVAL_MS = 1000

/**
 * Convert a backend task status (e.g. "running") into the label stored on
 * jobs (e.g. "Running"). Statuses missing from TaskStatusMap pass through unchanged.
 */
const toStatusLabel = (status) => TaskStatusMap[status] ?? status

/**
 * Return a ref whose `current` is reassigned to `value` on every render, so
 * callbacks that run later (promise handlers, polling) read the newest value
 * instead of the one captured when they were created.
 */
function useLatestRef(value) {
  const ref = React.useRef(value)
  ref.current = value
  return ref
}

/**
 * Manage model-training and SHAP tasks for the current workflow.
 *
 * Jobs are kept in the app store under "MS/Training/JobState". Failed training
 * jobs are copied to "MS/Training/ErrorState". Job statuses are the display
 * labels from TaskStatusMap ("Queued", "Running", ...).
 *
 * @param {{ enqueueSnackbar?: Function }} options - optional notifier called as
 *   `enqueueSnackbar(message, options)`; when absent, no notifications are shown.
 * @returns {object} training/error state plus the task actions (TrainModel,
 *   PollTask, StopTask, SubmitShap, ReloadTasks, ...).
 */
export default function useTrainModel({ enqueueSnackbar }) {
  const [workflow, setWorkflow] = useAppStoreKey("Workflow")
  const [trainingState, setTrainingState] = useAppStoreKey("MS/Training/JobState")
  const [errorState, setErrorState] = useAppStoreKey("MS/Training/ErrorState")
  const [currentTab, setCurrentTab] = useAppStoreKey("MS/Training/CurrentTab")
  const [visibleTabs,] = useAppStoreKey("MS/Training/VisibleTabs")
  const [MSState, setMSState] = useAppStoreKey("MixingStudioState")
  const [stopTaskIsLoading, setStopTaskIsLoading] = React.useState(false)

  // Async callbacks (axios responses, the polling loop) read store values
  // through this ref. The render-time variables would be stale there: a
  // long-running poll would overwrite the job list with the snapshot taken
  // when polling began and drop jobs submitted in the meantime.
  const latest = useLatestRef({ workflow, trainingState, errorState, visibleTabs, MSState })

  // Show a notification only if the caller supplied a notifier.
  const notify = useCallback((message, options) => {
    if (enqueueSnackbar) enqueueSnackbar(message, options)
  }, [enqueueSnackbar])

  // Write-through setters: update the ref immediately, so a second update
  // issued before the next render builds on this one rather than on old data.
  const commitTrainingState = useCallback((next) => {
    latest.current.trainingState = next
    setTrainingState(next)
  }, [latest, setTrainingState])

  const commitErrorState = useCallback((next) => {
    latest.current.errorState = next
    setErrorState(next)
  }, [latest, setErrorState])

  /**
   * Build a handler that validates the workflow and submits a training task.
   *
   * @param {object} modelParams - sent as `operations_dict.model_training`.
   * @param {string} modelName - task title; also checked for "classification".
   * @returns {Function} thunk; when validation passes it returns the request promise.
   */
  const TrainModel = useCallback((modelParams, modelName) => () => {
    if (!ioIsDefined(workflow)) {
      notify('Please set input and output columns before model training',
        { variant: "warning", autoHideDuration: AUTO_HIDE_DURATION })
      return
    }

    if (outputIsCatEncode(workflow)) {
      // The RFC accepts categorical values as output. So the output columns should not be encoded.
      notify('No need to categorical encode output columns. Please revert the encoding.',
        { variant: "warning", autoHideDuration: AUTO_HIDE_DURATION })
      return
    }

    if (outputContainsString(workflow) && !modelName.toLowerCase().includes("classification")) {
      // If the output contains string, the model should be a classification model.
      notify('Please set output column to numerical values or using a classification model',
        { variant: "warning", autoHideDuration: AUTO_HIDE_DURATION })
      return
    }

    // Mark training as disabled while the request is in flight.
    setMSState({ ...latest.current.MSState, enableTraining: false })
    const task_body = JSON.stringify({
      workflow_id: workflow.uuid,
      title: modelName,
      operation: 'train',
      operation_args: JSON.stringify({
        operations_dict: {
          model_training: modelParams
        }
      })
    })
    return aichemyProtoAxios.post('tasks', task_body, JSON_REQUEST_CONFIG)
      .then(res => {
        const newTrainingJob = {
          task_id: res.data.task_id,
          status: TaskStatusMap["queued"],
          info: {
            name: modelName,
            description: null
          },
          operation: 'train'
        }
        commitTrainingState([...(latest.current.trainingState ?? []), newTrainingJob])
        // Select the tab index after the last visible one. An empty list would
        // otherwise produce `undefined + 1`, i.e. NaN.
        const tabs = latest.current.visibleTabs
        setCurrentTab(tabs && tabs.length > 0 ? tabs[tabs.length - 1] + 1 : 0)
        notify('Job submitted to queue', { variant: "info" })
      })
      .catch(err => {
        console.error(err)
        if (err.response?.status === 409) { // conflict: an identical job already exists
          notify('Training job with these parameters already queued/running', { variant: "error" })
        } else {
          notify('Error during job creation.', { variant: "error" })
        }
      })
      .finally(() => {
        // Spread the current MS state so changes made during the request survive.
        setMSState({ ...latest.current.MSState, enableTraining: true })
      })
  }, [workflow, latest, notify, commitTrainingState, setMSState, setCurrentTab])

  /** Drop the error entry belonging to `task_id`. */
  const RemoveErrorState = (task_id) => {
    commitErrorState((latest.current.errorState ?? []).filter((state) => state.task_id !== task_id))
  }

  /**
   * Remove a job from the job list. If it was a training job, keep a copy of it
   * with `error` attached in the error list.
   */
  const setError = (task_id, error) => {
    const jobs = latest.current.trainingState
    if (!jobs || jobs.length === 0) return
    const job = jobs.find((state) => state.task_id === task_id)
    // retain bad training job details
    if (job?.operation === "train") {
      commitErrorState([...(latest.current.errorState ?? []), { ...job, error }])
    }
    // remove the training result due to the error
    commitTrainingState(jobs.filter((state) => state.task_id !== task_id))
  }

  /** Set the status label of one job. Copies instead of mutating store objects. */
  const updateStatus = (task_id, status) => {
    const jobs = latest.current.trainingState
    if (!jobs) return
    commitTrainingState(jobs.map((job) => (job.task_id === task_id ? { ...job, status } : job)))
  }

  /** Status label of a job in the most recent job list (for async callers). */
  const getLatestTaskStatus = (task_id) =>
    latest.current.trainingState?.find((job) => job.task_id === task_id)?.status

  /** Status label of a job in the job list of this render, or undefined. */
  const GetTaskStatus = useCallback((task_id) => {
    return trainingState?.find((job) => job.task_id === task_id)?.status
  }, [trainingState])

  /** Remove every job whose id is in the array `task_ids`. */
  const removeTaskFromState = (task_ids) => {
    const ids = new Set(task_ids)
    commitTrainingState((latest.current.trainingState ?? []).filter((job) => !ids.has(job.task_id)))
  }

  /**
   * Fetch the results of a successful task and merge its models into the
   * workflow. Then drop this task, and any job whose id matches a model's
   * `BE_task_id`, from the job list.
   */
  const handleTaskSuccess = (task_id, onSuccess, onError) => {
    return getTaskResults(task_id).then(res => {
      const wf = res.data.data
      // Merge into the workflow as it is now, not as it was when polling began.
      wf.models = combineModelLists([...(latest.current.workflow?.models ?? [])], wf.models)
      latest.current.workflow = wf
      setWorkflow(wf)
      onSuccess && onSuccess(wf)
      const finishedIds = wf.models.map((model) => model.BE_task_id)
      finishedIds.push(task_id)
      removeTaskFromState(finishedIds)
    }).catch((err) => {
      console.error(err)
      notify('Error during retrieving results.', { variant: "error" })
      setError(task_id, err.message)
      onError && onError()
    })
  }

  /** Fetch the failure message of a task, move the job to the error list and notify. */
  const handleTaskError = (task_id, onError) => {
    return getTaskResults(task_id).then(res => {
      const err_msg = res.data.data.error
      console.error(err_msg)
      setError(task_id, err_msg)
      notify('Task failed during execution.', { variant: "error" })
    }).catch((err) => {
      console.error(err)
      notify('Task failed and failed to log the underlying issue...', { variant: "error" })
      setError(task_id, err.message)
    }).finally(() => {
      onError && onError()
    })
  }

  /**
   * Poll a task until it completes (or `mountedRef.current` becomes false).
   * Status changes are mirrored into the job list. When polling ends, the
   * success / error / killed handling runs.
   *
   * @param task_id - id returned by the task API.
   * @param {{ current: boolean }} mountedRef - polling stops once this is false.
   * @param {Function} [onStatusCheck] - called with the initial label and on every change.
   * @param {Function} [onSuccess] - receives the updated workflow after a successful task.
   * @param {Function} [onError] - called when the task or result retrieval fails.
   * @returns {Promise} settles when polling ends; result fetching continues after that.
   */
  const PollTask = (task_id, mountedRef, onStatusCheck, onSuccess, onError) => {
    let status = ''
    // set initial status
    onStatusCheck && onStatusCheck(getLatestTaskStatus(task_id))
    return poll(
      checkTaskStatus,
      (res) => {
        // `poll` presumably stops once this returns true (it does so for completed tasks).
        if (!mountedRef.current) {
          return true
        }
        if (!res) {
          return false
        }
        status = res.data.status
        const label = TaskStatusMap[status]
        if (getLatestTaskStatus(task_id) !== label) {
          updateStatus(task_id, label)
          onStatusCheck && onStatusCheck(label)
        }
        return status.startsWith("complete")
      },
      POLL_INTERVAL_MS,
      task_id
    ).then((res) => {
      status = res.data.status
    }).catch((err) => {
      updateStatus(task_id, TaskStatusMap["complete#error"])
      console.error(err)
    }).finally(() => {
      switch (status) {
        case "complete#success":
          handleTaskSuccess(task_id, onSuccess, onError)
          break
        case "complete#error":
          handleTaskError(task_id, onError)
          break
        case "complete#killed":
          // TODO: do we need any specific handling for this we want to expose?
          // Wrap in an array: a bare string would make `includes` match substrings.
          removeTaskFromState([task_id])
          break
        default:
          break
      }
    })
  }

  /**
   * Ask the backend to kill a task. On success the job is dropped and the
   * selection moves to the nearest visible tab before the current one (or to
   * `undefined` if there is none).
   *
   * @returns {Promise<void>} settles once the request has been handled.
   */
  const StopTask = useCallback((task_id) => {
    setStopTaskIsLoading(true)
    return aichemyProtoAxios.post(`tasks/${task_id}/kill`, {}, JSON_REQUEST_CONFIG)
      .then(() => {
        commitTrainingState((latest.current.trainingState ?? []).filter(s => s.task_id !== task_id))
        // get nearest tab to visible tab that was stopped
        const earlierTabs = (visibleTabs ?? []).filter(t => t < currentTab)
        setCurrentTab(earlierTabs.length === 0 ? undefined : earlierTabs[earlierTabs.length - 1])
        notify('Task successfully stopped.', { variant: "info" })
      })
      .catch(err => {
        console.error(err)
        notify('Failed to stop task.', { variant: "error" })
      })
      .finally(() => {
        // Reset on failure too, otherwise the loading flag would stay on forever.
        setStopTaskIsLoading(false)
      })
  }, [latest, notify, commitTrainingState, setCurrentTab, currentTab, visibleTabs])

  /** The SHAP job tracked for the model at `modelIdx`, or undefined. */
  const GetShapTaskByModelId = useCallback((modelIdx) => {
    return trainingState?.find(s => s.operation === 'shap' && s.info?.model_idx === modelIdx)
  }, [trainingState])

  /**
   * Submit a SHAP computation for the model at `modelIdx`. The model index is
   * sent as the task title, which is how ReloadActiveShap maps it back.
   */
  const SubmitShap = useCallback((modelIdx) => {
    const model_uuid = workflow.models[modelIdx].uuid
    const model_params = {
      operations_dict: {
        model_post_process: {
          ComputeShap: {
            kwargs: { model_id: model_uuid }
          }
        }
      },
    }
    const task_body = JSON.stringify({
      workflow_id: workflow.uuid,
      operation: 'update',
      operation_args: JSON.stringify(model_params),
      title: modelIdx.toString()
    })
    return aichemyProtoAxios.post('tasks', task_body, JSON_REQUEST_CONFIG)
      .then(res => {
        const newJob = {
          task_id: res.data.task_id,
          status: TaskStatusMap["queued"],
          info: {
            model_idx: modelIdx,
            model_uuid: model_uuid
          },
          operation: 'shap',
        }
        commitTrainingState([...(latest.current.trainingState ?? []), newJob])
      })
      .catch(err => console.error(err))
  }, [workflow, latest, commitTrainingState])

  /**
   * Replace the job list with the workflow's current training tasks.
   *
   * @returns {Promise<object[]|undefined>} the new jobs, or undefined on failure.
   */
  const ReloadActiveTrainingTasks = useCallback(() => {
    const url = `workflow/${workflow.uuid}/tasks?operation=train&status=current`
    return aichemyProtoAxios.get(url, JSON_REQUEST_CONFIG)
      .then(res => {
        const jobs = res.data.map(task => ({
          task_id: task.task_id,
          // Store labels like every other code path does (backend sends e.g. "running").
          status: toStatusLabel(task.status),
          info: {
            name: task.title,
            description: null
          },
          operation: task.operation
        }))
        commitTrainingState(jobs)
        return jobs
      })
      .catch(err => console.error(err))
  }, [workflow, commitTrainingState])

  /**
   * Load the workflow's current SHAP ("update") tasks and store them after
   * `otherJobs`, or on their own when `otherJobs` is undefined.
   */
  const ReloadActiveShap = useCallback((otherJobs) => {
    const url = `workflow/${workflow.uuid}/tasks?operation=update&status=current`
    return aichemyProtoAxios.get(url, JSON_REQUEST_CONFIG)
      .then(res => {
        const shapJobs = res.data.map(task => {
          // SubmitShap stores the model index as the task title.
          const modelIdx = Number(task.title)
          return {
            task_id: task.task_id,
            status: toStatusLabel(task.status),
            info: {
              model_idx: modelIdx,
              // Tolerate tasks whose model no longer exists instead of failing the whole reload.
              model_uuid: workflow.models?.[modelIdx]?.uuid
            },
            operation: 'shap'
          }
        })
        // Store a flat list (previously `[shapJobs]` nested the array).
        commitTrainingState(otherJobs === undefined ? shapJobs : otherJobs.concat(shapJobs))
      })
      .catch(err => console.error(err))
  }, [workflow, commitTrainingState])

  /** Reload training jobs, then append the SHAP jobs. Returns the combined promise. */
  const ReloadTasks = () =>
    ReloadActiveTrainingTasks().then((trainingJobs) => ReloadActiveShap(trainingJobs))

  return {
    trainingState,
    PollTask,
    StopTask,
    GetTaskStatus,
    ReloadTasks,
    stopTaskIsLoading,
    TrainModel,
    SubmitShap,
    GetShapTaskByModelId,
    errorState,
    RemoveErrorState
  }
}