import React, { useEffect, useCallback, useRef, useState } from "react"
import Button from "@material-ui/core/Button";
import DescriptionIcon from "@material-ui/icons/Description"
import Typography from "@material-ui/core/Typography"
import CircularProgress from "@material-ui/core/CircularProgress"
import { useSnackbar } from "notistack"
import { CSVLink } from "react-csv";
import HotTable from "../../HoT/StyledHoT"
import { useAppStoreKey } from "../../../AppStore"
import { BatchPrediction } from "./AichemyPredict"
import useModelData from './useModelData'
import AddIcon from '@material-ui/icons/Add';
import ButtonWithSpinner from "../../General/ButtonWithSpinner";

export default function BatchPredictionView({ classes, cellStyle, workflow_id, modelID }) {
  // Active workflow from the app store. Its `name` seeds the exported CSV filename
  // and its `data.info[active_sheet].column_types` is sent with each prediction.
  const [workflow,] = useAppStoreKey("Workflow");
  // Input/output column names of the selected model, in the order used for the
  // table headers (inputs first, then outputs).
  const { OriginalColumnOrdering } = useModelData({ modelID: modelID });
  // App-store copy of the table, shaped { batch, batchShap }.
  // NOTE(review): presumably kept so the table survives unmounting - confirm against AppStore.
  const [stashedData, setStashedData] = useAppStoreKey("MS/Predict/Batch")
  const [batchData, setBatchData] = useState(undefined); // table rows (input cells followed by output cells); undefined until initialised
  const [, setBatchShap] = useState(undefined); // batch prediction shap result (stored, but never read in this component)
  // Payload handed to the hidden CSVLink: rows to write plus the download filename.
  const [csvData, setCSVData] = useState({ data: [], filename: 'download.csv' })
  // True while a prediction task is in flight; drives both spinners.
  const [isLoading, setIsLoading] = useState(false);
  const { enqueueSnackbar } = useSnackbar();
  // Number of rows the table shows; grown in steps of 10 by the "Add more rows" button.
  const [nRows, setNRows] = useState(20);
  // Ref to the HotTable wrapper; `current.hotInstance.getData()` is used to read the live grid contents.
  let hotTableComponent = React.useRef();
  // Ref to the hidden CSVLink; its `link` element is clicked programmatically to start the download.
  const csvLink = useRef();

  /**
   * Build the initial table state for the batch view.
   *
   * The default is a single blank row holding one empty-string cell per input
   * and output column. When `predictionData.batch` is defined, a shallow copy of
   * it is used instead. Stashed SHAP values are not restored: `batchShap` is
   * always a fresh empty object.
   *
   * @param {{batch?: Array}|undefined} predictionData previously stashed state, if any
   * @returns {{batchData: Array, batchShap: Object}} initial rows and SHAP container
   */
  const initDataDict = useCallback((predictionData) => {
    const { input, output } = OriginalColumnOrdering
    const blankRow = [...input, ...output].map(() => '')
    const hasStashedBatch = predictionData !== undefined && predictionData.batch !== undefined

    return {
      batchData: hasStashedBatch ? [...predictionData.batch] : [blankRow],
      batchShap: {}
    }
  }, [OriginalColumnOrdering])

  // Seed the local table state from the app store the first time through, and
  // make sure the app-store entry is an object (never undefined) afterwards.
  useEffect(() => {
    if (batchData !== undefined) {
      // Already initialised: only repair a missing app-store entry.
      if (stashedData === undefined) {
        setStashedData({ batch: undefined, batchShap: undefined })
      }
      return
    }

    const { batchData: initialRows, batchShap: initialShap } = initDataDict(stashedData)
    setBatchData(initialRows)
    setBatchShap(initialShap)
    setStashedData({ batch: initialRows, batchShap: initialShap })
  }, [initDataDict, stashedData, setStashedData, batchData])


  /**
   * Column headers for the table and the CSV export.
   *
   * @returns {Array} a new array with the input column names followed by the output column names
   */
  const getAllHeaders = useCallback(() => {
    const { input, output } = OriginalColumnOrdering
    return [...input, ...output]
  }, [OriginalColumnOrdering])

  /**
   * Apply a list of cell edits to a copy of the table, store the result in
   * component state and in the app store, and return the rows that should be
   * re-run.
   *
   * @param {Array<Array>} changes edits as [rowIndex, columnIndex, oldValue, newValue]
   *   tuples (only positions 0, 1 and 3 are read). A non-array value is treated
   *   as "no edits".
   * @returns {Array<Array|string>} one entry per table row, in table order: the
   *   row with its empty cells replaced by null when the row was edited and all
   *   of its input cells are filled, otherwise the placeholder "" so that row
   *   positions are preserved.
   */
  const updateBatchInput = (changes) => {
    const edits = Array.isArray(changes) ? changes : []
    const nInput = OriginalColumnOrdering.input.length
    const rowWidth = nInput + OriginalColumnOrdering.output.length
    const isFilled = (val) => val !== "" && val !== null && val !== undefined

    // Copy every row so the arrays currently held in state are never mutated.
    const updateData = (Array.isArray(batchData) ? batchData : []).map((row) => [...row])

    // A Set keeps the "was this row edited?" lookup constant-time per row.
    const affectedRows = new Set()
    edits.forEach(([rowIndex, columnIndex, , newValue]) => {
      if (!Number.isInteger(rowIndex) || rowIndex < 0) {
        return
      }
      // The grid can display more rows than the state array holds (it is rendered
      // with a minimum row count), so pad with blank rows instead of throwing.
      while (updateData.length <= rowIndex) {
        updateData.push(new Array(rowWidth).fill(""))
      }
      updateData[rowIndex][columnIndex] = newValue
      affectedRows.add(rowIndex)
    })

    // store the changes; the app store gets its own row copies so later in-place
    // edits of the state rows cannot leak into it
    setBatchData(updateData)
    setStashedData({ ...stashedData, batch: updateData.map((row) => row.slice()) })

    // find the relevant rows to rerun
    return updateData.map((row, rowIndex) => {
      // row wasnt modified, dont update anything
      if (!affectedRows.has(rowIndex)) {
        return ""
      }
      // skip partially filled lines: every input cell must hold a value
      if (row.slice(0, nInput).filter(isFilled).length !== nInput) {
        return ""
      }
      // null out empty cells
      return row.map((item) => (item === "" || item === undefined ? null : item))
    })
  }

  /**
   * Collect every table row that is ready to be sent for prediction.
   *
   * A row qualifies when each of its input cells (the first
   * `OriginalColumnOrdering.input.length` cells) holds a value other than "",
   * null or undefined. Qualifying rows are returned as copies with their
   * remaining empty cells replaced by null; all other rows are dropped.
   *
   * @returns {Array<Array>} the complete rows, in table order; an empty array
   *   when the table has not been initialised yet
   */
  const getBatchInput = () => {
    // State is undefined until the initialising effect has run.
    if (!Array.isArray(batchData)) {
      return []
    }
    const nInput = OriginalColumnOrdering.input.length
    const isFilled = (val) => val !== "" && val !== null && val !== undefined

    return batchData
      // remove partially filled lines
      .filter((row) => Array.isArray(row) && row.slice(0, nInput).filter(isFilled).length === nInput)
      // null out empty cells (map builds new rows, so state is not mutated)
      .map((row) => row.map((item) => (item === "" || item === undefined ? null : item)))
  }


  /**
   * Submit table rows for batch prediction and merge the results back into the
   * table once the task has finished.
   *
   * @param {Array<Array>} [changes] cell edits as [row, column, oldValue, newValue]
   *   tuples. When omitted, every complete row is submitted; otherwise only the
   *   edited rows are (see updateBatchInput).
   * @returns {void}
   */
  function runBatchPrediction(changes) {
    const inputData = changes === undefined ? getBatchInput() : updateBatchInput(changes)

    const onPostSuccess = () => {
      enqueueSnackbar("Task Staged for Execution", { variant: "success" })
    }
    const onPostFailure = (err) => {
      enqueueSnackbar("Error creating task (see console for details)", { variant: "error" })
      console.error(err)
      setIsLoading(false)
    }

    const onPollingSuccess = (res) => {
      setIsLoading(false)

      // Prefer the live grid contents. The table may have been unmounted while the
      // task was running, in which case fall back to the rows captured in state.
      const hot = hotTableComponent.current ? hotTableComponent.current.hotInstance : null
      const sourceRows = hot ? hot.getData() : (Array.isArray(batchData) ? batchData : [])
      const updatedBatchData = sourceRows.map(row => [...row])
      const nInput = OriginalColumnOrdering.input.length

      // overwrite table data with result: the i-th key of `remap` gives the table
      // row that receives the i-th prediction row, written after the input columns
      const { data, remap, shap } = res
      Object.keys(remap).forEach((predRow, index) => {
        const targetRow = updatedBatchData[remap[predRow]]
        const predicted = data[index]
        // Ignore results that point at rows which no longer exist.
        if (targetRow === undefined || predicted === undefined) {
          return
        }
        predicted.forEach((value, columnIndex) => {
          targetRow[nInput + columnIndex] = value
        })
      })

      // update state (the app store receives its own copy of every row)
      setBatchData(updatedBatchData)
      setBatchShap(shap)
      setStashedData({
        batch: updatedBatchData.map(arr => arr.slice()),
        batchShap: shap
      })

      enqueueSnackbar("Task Executed Successfully", { variant: "success" })
    }

    const onPollingFailure = (err) => {
      enqueueSnackbar("Error executing task (see console for details)", { variant: "error" })
      console.error(err)
      setIsLoading(false)
    }

    setIsLoading(true)
    let submitted = false
    try {
      // pass to the BatchPrediction func
      submitted = BatchPrediction(
        workflow_id, inputData, OriginalColumnOrdering, workflow.data.info[workflow.data.active_sheet].column_types, modelID,
        onPostSuccess, onPostFailure, onPollingSuccess, onPollingFailure)
    } catch (err) {
      // Building or submitting the request failed synchronously; report it and
      // clear the spinner instead of leaving the view stuck in the loading state.
      onPostFailure(err)
      return
    }

    if (!submitted) {
      setIsLoading(false)
    }
  }

  // Start the CSV download by programmatically clicking the element exposed as
  // `link` on the hidden CSVLink ref.
  const ExportToCSV = () => {
    const { link } = csvLink.current
    link.click()
  }

  /**
   * Rebuild the payload for the CSV export and store it with setCSVData.
   *
   * When the table state holds at least one row and the grid instance is
   * available, the payload is the header row followed by the grid's current
   * contents, and the filename is the part of the workflow name before its first
   * "." plus "_MS.csv". Otherwise an empty payload named 'test.csv' is stored.
   *
   * @returns {void}
   */
  let generateCSVData = () => {
    const hot = hotTableComponent.current ? hotTableComponent.current.hotInstance : undefined
    const hasRows = Array.isArray(batchData) && Array.isArray(batchData[0])
    if (!hasRows || !hot) {
      setCSVData({ data: [], filename: 'test.csv' })
      return
    }

    // Fall back to a generic base name when the workflow has no usable name.
    const oriFilename = workflow && typeof workflow.name === 'string' ? workflow.name : 'download'
    const csvName = oriFilename.split('.')[0] + '_MS.csv'

    // convert every " to "" so headers are correctly written to the csv file;
    // a string pattern would only replace the first occurrence, hence the /g regex
    const newHeaders = getAllHeaders().map((item) => (
      typeof item === 'string' ? item.replace(/"/g, "\"\"") : item
    ))

    // Read the rows from the grid rather than from state: the grid is the
    // up-to-date source, state can lag behind it.
    const newData = hot.getData()

    setCSVData({
      data: [newHeaders, ...newData],
      filename: csvName
    })
  }
  // Rebuild the CSV payload after every render in which the table state, the
  // workflow, or the header callback has changed.
  useEffect(generateCSVData, [batchData, workflow, getAllHeaders])

  // Layout, top to bottom:
  //  - title row, with a spinner shown while `isLoading` is true
  //  - a one-line instruction
  //  - the editable grid bound to `batchData`, headed by input then output column
  //    names; its height and minimum row count both follow `nRows`
  //    (NOTE(review): 24 looks like the per-row pixel height - confirm against the table styles)
  //  - a hidden CSVLink that ExportToCSV clicks through `csvLink`
  //  - three right-floated buttons: export, add 10 rows, run prediction
  // The grid's `beforeChange` hook is commented out, so predictions only run
  // when the "Run Prediction" button is pressed (which submits all complete rows).
  return (<>
    <div style={{ display: "flex", flexDirection: "row" }}>
      <Typography variant="h6" gutterBottom style={{ textAlign: "left", marginTop: 30 }}>
        Batch Prediction:
      </Typography>
      {isLoading && <CircularProgress size={24} style={{ marginLeft: 20, marginTop: 30 }} />}
    </div>
    <Typography variant="subtitle1" gutterBottom style={{ textAlign: "left", marginLeft: 24, marginBottom: 15 }}>
      Add full rows of input data to the table below to generate predictions using the selected model.
    </Typography>
    <div className={classes.htRoot}>
      <HotTable
        data={batchData}
        colHeaders={getAllHeaders()}
        ref={hotTableComponent}
        rowHeaders={true}
        columnSorting={false}
        height={String(nRows * 24 + 8)}
        stretchH="all"
        minRows={String(nRows)}
        cells={cellStyle()}
        settings={{ outsideClickDeselects: false }}
        style={{ fontSize: "smaller", fontFamily: "Roboto" }}
        id="predict_table2"
        // beforeChange={runBatchPrediction}
      />
      {/*For future use*/}
      {/* {(Object.keys(singleShap).length > 0) && console.log(batchShap)} */}
    </div>
    <CSVLink
      data={csvData.data}
      filename={csvData.filename}
      className="hidden"
      ref={csvLink}
      target="_blank" />
    <Button
      variant="contained"
      color="primary"
      className={classes.margin}
      onClick={ExportToCSV}
      style={{ background: "linear-gradient(45deg, #00c853 30%, #00e676 90%)", float: "right", marginTop: 6 }}
    >
      <DescriptionIcon style={{ marginRight: 12 }} />
      Export Results to Excel
    </Button>
    <Button
      variant="contained"
      color="primary"
      className={classes.margin}
      style={{ marginTop: 6, marginRight: 12, float: "right" }}
      onClick={() => setNRows(nRows + 10)}
    >
      <AddIcon style={{ marginRight: 12 }} />
      Add more rows
    </Button>
    <ButtonWithSpinner
      loading={isLoading}
      variant="contained"
      color="primary"
      className={classes.margin}
      style={{ marginTop: 6, marginRight: 12, float: "right" }}
      onClick={() => runBatchPrediction()}
    >
      Run Prediction
    </ButtonWithSpinner>
  </>)
}