import React, { useEffect, useCallback, useState, useRef, useMemo } from "react"
import { makeStyles } from '@material-ui/core/styles';

import { Button, Paper, Grid, MenuItem, Select, Typography } from "@material-ui/core"

import Divider from '@material-ui/core/Divider';
import List from '@material-ui/core/List';
import DescriptionIcon from "@material-ui/icons/Description"

import Plot from "react-plotly.js";
import { useSnackbar } from "notistack"
import { CSVLink } from "react-csv";
import { useAppStoreKey } from "../../../../AppStore"

import { BatchPrediction } from "../AichemyPredict"
import useModelData from '../useModelData'
import TraceDialog from './TraceDialog'
import SweepPointsDialog from './SweepPointsDialog'
import TraceListItem from './TraceListItem'



// Layout classes for the sweep panel: a growing root container, the plot region
// (60% wide) and the control column (30% wide). Grid spacing is a Grid prop rather
// than a CSS property, so it is intentionally not declared here.
const useStyles = makeStyles(() => ({
  root: {
    flexGrow: 1,
    justifyContent: "flex-start",
  },
  plotRegion: {
    width: "60%",
  },
  control: {
    width: "30%"
  },
}));

function SectionTitle(text) {
  return (
    <Typography
      variant="subtitle1"
      gutterBottom
      style={{
        textAlign: "left",
        marginBottom: 5,
        fontWeight: 500
      }}
    >
      {text}
    </Typography>
  )
}

function SelectField({ title, value, onChange, items, disabled }) {
  return (
    <>
      {SectionTitle(title)}
      <Select
        style={{ textAlign: "left", width: "100%" }}
        value={value}
        onChange={onChange}
        disabled={disabled}
      >
        {items?.map(option => (
          <MenuItem key={option} value={option}>
            {option}
          </MenuItem>
        ))}
      </Select>
    </>
  )
}

export default function PredictionSweep({ workflow_id, modelID }) {
  const classes = useStyles();
  const { enqueueSnackbar } = useSnackbar();
  const { GetModelData, OriginalColumnOrdering, ActiveInfo, DecodeCategorical } = useModelData({ modelID: modelID });
  const [stashedData, setStashedData] = useAppStoreKey("MS/Predict/1D-Sweep")
  const [inputCol, setInputCol] = useState(undefined)
  const [otherInputs, setOtherInputs] = useState([])
  const [outputCol, setOutputCol] = useState(undefined)
  const [fixedValues, setFixedValues] = useState({}) // default value for each input dimension
  const [inputSteps, setInputSteps] = useState({}) // map of column to X-steps
  const [traces, setTraces] = useState([])
  // dialog fields
  const [loading, setLoading] = useState(false)
  const [sweepDialogState, setSweepDialogState] = useState(false)
  const [traceDialogState, setTraceDialogState] = useState(false)
  const [traceDialogType, setTraceDialogType] = useState("add")
  const [traceDialogIndex, setTraceDialogIndex] = useState(-1)
  // download fields
  const csvLink = useRef();
  const [csvData, setCSVData] = useState({ data: [], filename: 'sweep.csv' })

  const getFixedValues = useCallback(() => {
    const { data, selected_sheet } = GetModelData()
    const data_df = data.data_df[selected_sheet]
    // get first full row of the data table
    for (let i = 0; i < data_df.length; i++) {
      let full = true
      let untransformed = DecodeCategorical(data_df[i])
      OriginalColumnOrdering.input.forEach(col => {
        full &&= (untransformed[col] !== undefined && untransformed[col] !== null)
      })
      if (full) {
        return { ...untransformed }
      }
    }
    // use the first value from each column that is filled
    let fixed = {}
    OriginalColumnOrdering.input.forEach(col => {
      for (let i = 0; i < data_df.length; i++) {
        let untransformed = DecodeCategorical(data_df[i])
        if (untransformed[col] !== undefined && untransformed[col] !== null) {
          fixed[col] = untransformed[col]
          break
        }
      }
    })
    return fixed
  }, [GetModelData, OriginalColumnOrdering, DecodeCategorical])

  const initDataDict = useCallback((loadedData) => {
    // handle 
    let data = {
      inputCol: "",
      otherInputs: [],
      outputCol: "",
      fixedValues: getFixedValues(),
      inputSteps: {},
      traces: []
    }
    if (loadedData === undefined) {
      return data
    }

    Object.keys(loadedData).forEach((field) => {
      if (loadedData[field] !== undefined) {
        data[field] = loadedData[field]
      }
    })
    return data

  }, [getFixedValues])

  // handle loading stashed state if it exists
  useEffect(() => {
    if (inputCol === undefined || outputCol === undefined) {
      let data = initDataDict(stashedData)
      setInputCol(data.inputCol)
      setOtherInputs(data.otherInputs)
      setOutputCol(data.outputCol)
      setFixedValues(data.fixedValues)
      setInputSteps(data.inputSteps)
      setTraces(data.traces)
      setStashedData(data)
    } else if (stashedData === undefined) {
      setStashedData({
        inputCol: undefined,
        otherInputs: undefined,
        outputCol: undefined,
        fixedValues: undefined,
        inputSteps: undefined,
        traces: undefined
      })
    }
  }, [initDataDict, stashedData, setStashedData, inputCol, outputCol])

  const inputColType = useMemo(() => {
    if (ActiveInfo.column_types === undefined) {
      return "number"
    }
    return ActiveInfo.column_types[inputCol]
  }, [inputCol, ActiveInfo])

  const otherInputData = useMemo(() => {
    let data = {}
    otherInputs.forEach((col) => {
      data[col] = {
        options: ActiveInfo.ori_data_range[col],
        type: ActiveInfo.column_types[col]
      }
    })
    return data
  }, [otherInputs, ActiveInfo])

  const getLayout = () => {
    return {
      autosize: true,
      hovermode: 'closest',
      margin: { b: 50, t: 50 },
      xaxis: { title: { text: inputCol }, showticklabels: true },
      yaxis: { title: { text: outputCol }, showticklabels: true },
      showlegend: true
    }
  }

  const getTrace = useCallback(() => {
    let plotlyTrace = []
    traces.forEach((trace) => {
      if (trace.output[outputCol] !== undefined) {
        let customData = Object.keys(trace.data).map((key) => key + ": " + trace.data[key])
        let hoverTemplate = `${trace.name}<br>` +
          "%{yaxis.title.text}: %{y}<br>" +
          "%{xaxis.title.text}: %{x}<br>" + customData.join("<br>") +
          "<extra></extra>"
        plotlyTrace.push({
          x: inputSteps[inputCol],
          y: trace.output[outputCol],
          columns: [inputCol, outputCol],
          mode: 'lines+markers',
          type: 'scatter',
          marker: { size: 8 },
          hovertemplate: hoverTemplate,
          name: trace.name,
        })
      }
    })
    return plotlyTrace
  }, [traces, outputCol, inputSteps, inputCol])

  const handleInputChange = (ev) => {
    // https://stackoverflow.com/a/66902484
    function linspace(start, stop, num, endpoint = true) {
      const div = endpoint ? (num - 1) : num;
      const step = (stop - start) / div;
      return Array.from({ length: num }, (_, i) => parseFloat((start + step * i).toFixed(2)));
    }
    const col = ev.target.value
    setInputCol(col)
    let remaining = OriginalColumnOrdering.input.filter((val) => val !== col)
    setOtherInputs(remaining)

    let steps = { ...inputSteps }
    let currTraces = [...traces]
    if (inputSteps[col] === undefined) {
      const data_range = ActiveInfo.ori_data_range
      const data_types = ActiveInfo.column_types
      if (data_types[col] === "number") {
        steps[col] = linspace(data_range[col][0], data_range[col][1], 10, true)
      } else if (data_types[col] === "string") {
        steps[col] = [...data_range[col]]
      } else {
        // no idea what to do for smiles here
      }
      // clear the prior traces
      currTraces = []
    }
    setInputSteps(steps)
    setTraces(currTraces)
    setStashedData({
      ...stashedData,
      inputCol: col,
      otherInputs: remaining,
      inputSteps: steps,
      traces: currTraces
    })
  }

  const handleOutputChange = (ev) => {
    setOutputCol(ev.target.value)
    setStashedData({
      ...stashedData,
      outputCol: ev.target.value
    })
  }

  const onListItemClick = (index) => () => {
    setTraceDialogIndex(index)
    setTraceDialogState(true)
    setTraceDialogType("edit")
  }

  const onAddTrace = () => {
    setTraceDialogIndex(traces.length + 1)
    setTraceDialogState(true)
    setTraceDialogType("add")
  }

  const handleCopy = (index) => () => {
    let newTrace = {
      name: traces[index].name + " - Copy",
      data: { ...traces[index].data },
      output: {},
    }
    // copy any values that exist
    OriginalColumnOrdering.output.forEach((col) => {
      if (traces[index].output[col]) {
        newTrace.output[col] = [...traces[index].output[col]]
      }
    })
    setTraceDialogIndex(traces.length)
    setTraces([...traces, newTrace])
    setStashedData({ ...stashedData, traces: [...traces, newTrace] })
  }

  const handleDelete = (index) => () => {
    let newTraces = [...traces]
    newTraces.splice(index, 1)
    setTraceDialogIndex(traces.length - 1)
    setTraces(newTraces)
    setTraceDialogType("add")
    setStashedData({ ...stashedData, traces: newTraces })
  }

  const launchTask = (trace, steps, index) => {
    const onPostFailure = (err) => {
      enqueueSnackbar("Error creating task (see console for details)", { variant: "error" })
      console.error(err)
      setLoading(false)
    }
    const onPollingFailure = (err) => {
      enqueueSnackbar("Error executing task (see console for details)", { variant: "error" })
      console.error(err)
      setLoading(false)
    }
    const onPollingSuccess = (res) => {
      setLoading(false)
      let { data } = res
      let output = {}
      OriginalColumnOrdering.output.forEach((col, col_index) => {
        output[col] = data.map((row) => row[col_index])
      })
      trace.output = output

      let newTraces = [...traces]
      newTraces[index] = trace
      setTraces(newTraces)
      setStashedData({ ...stashedData, traces: newTraces })
    }

    setLoading(true)
    let inputData = steps.map((step) => {
      let row = []
      OriginalColumnOrdering.input.forEach((col) => {
        if (col === inputCol) {
          row.push(step)
        } else {
          row.push(trace.data[col])
        }
      })
      return row
    })
    let submitted = BatchPrediction(
      workflow_id, inputData, OriginalColumnOrdering, ActiveInfo.column_types, modelID,
      undefined, onPostFailure, onPollingSuccess, onPollingFailure
    )
    if (!submitted) {
      setLoading(false)
    }
  }

  const handleAdd = (newTrace) => {
    setTraceDialogIndex(traces.length)
    setTraces([...traces, newTrace])
    setStashedData({ ...stashedData, traces: [...traces, newTrace] })
    launchTask(newTrace, inputSteps[inputCol], traces.length)
  }

  const handleEdit = (updatedTrace, index) => {
    let newTraces = [...traces]
    newTraces[index] = updatedTrace
    setTraces(newTraces)
    setStashedData({ ...stashedData, traces: newTraces })
    setTraceDialogType("add")
    launchTask(updatedTrace, inputSteps[inputCol], index)
  }

  // trace details
  const traceDialogName = () => {
    if (traceDialogType === "add") {
      return "Trace " + (traces.length + 1)
    } else {
      return traces[traceDialogIndex].name
    }
  }

  const traceDialogData = () => {
    if (traceDialogType === "add") {
      return traces.length > 0 ? { ...traces[traces.length - 1].data } : fixedValues
    } else {
      return { ...traces[traceDialogIndex].data }
    }
  }

  const onTraceDialogSuccess = () => {
    if (traceDialogType === "add") {
      return handleAdd
    } else {
      return handleEdit
    }
  }
  const traceDialogTitle = () => {
    if (traceDialogType === "add") {
      return "Add Sweep"
    } else {
      return "Edit Sweep"
    }
  }

  const onTraceCopy = () => {
    if (traceDialogType === "add") {
      return undefined
    } else {
      return handleCopy(traceDialogIndex)
    }
  }

  // sweep points
  const onSweepSuccess = (values) => {
    let steps = { ...inputSteps }
    steps[inputCol] = [...values]
    setInputSteps(steps)
    setStashedData({ ...stashedData, inputSteps: steps })
    traces.forEach((trace, index) => launchTask(trace, values, index))
  }

  //#region csv export
  const ExportToCSV = () => {
    csvLink.current.link.click()
  }

  let generateCSVData = () => {
    if (loading) return

    if (traces.length === 0) {
      setCSVData({ data: [], filename: 'empty.csv' })
      return
    }

    let data = [["Trace", ...OriginalColumnOrdering.input, ...OriginalColumnOrdering.output]]
    traces.forEach((trace) => {
      inputSteps[inputCol].forEach((step, step_index) => {
        let row = [trace.name]
        OriginalColumnOrdering.input.forEach((col) => {
          if (col === inputCol) {
            row.push(step)
          } else {
            row.push(trace.data[col])
          }
        })
        OriginalColumnOrdering.output.forEach((col) => {
          if (trace.output[col] !== undefined && trace.output[col][step_index] !== undefined) {
            row.push(trace.output[col][step_index])
          } else {
            row.push("")
          }
        })
        data.push(row)
      })
    })
    setCSVData({ data: data, filename: 'sweep.csv' })
  }

  useEffect(generateCSVData, [traces, inputSteps, inputCol, loading, OriginalColumnOrdering])
  //#endregion csv export


  return (
    <Grid container className={classes.root}>
      <Grid item xs={12}>
        <Typography variant="h6" gutterBottom style={{ textAlign: "left", marginTop: 24 }}>
          1-D Variable Sweep
        </Typography>
        <Typography gutterBottom style={{ textAlign: "left", marginTop: 24 }}>
          Perform 1D Parameter Sweeps by selecting the input X variable you'd like to vary along with the points to sample this variable over.
          For each new Sweep, you define fixed values for the remaining input variables allowing you to model output behavior for different set points.
        </Typography>
      </Grid>
      <Grid item className={classes.plotRegion}>
        <Plot data={getTrace()} layout={getLayout()} style={{ alignItems: "inherit" }} />
      </Grid>
      <Divider orientation="vertical" variant="middle" flexItem />
      <Grid item className={classes.control}>
        <SelectField
          title="Select Y Variable (Variable to observe)"
          value={outputCol === undefined ? "" : outputCol}
          onChange={handleOutputChange}
          items={OriginalColumnOrdering.output}
          disabled={loading}
        />

        <SelectField
          title="Select X Variable (Variable to sweep over)"
          value={inputCol === undefined ? "" : inputCol}
          onChange={handleInputChange}
          items={OriginalColumnOrdering.input}
          disabled={loading}
        />

        <Button
          variant="contained"
          color="primary"
          component="span"
          onClick={() => setSweepDialogState(true)}
          disabled={inputCol === "" || loading}
          style={{ marginTop: 15, marginLeft: 4, marginRight: 4, width: "100%" }}
        >
          Set X Sweep Points
        </Button>
        <SweepPointsDialog
          title={"Set points to sample for: " + inputCol}
          type={inputColType}
          initialSweepPoints={inputSteps[inputCol]}
          onSuccess={onSweepSuccess}
          open={sweepDialogState}
          setOpen={setSweepDialogState}
        />
        <Button
          variant="contained"
          color="primary"
          component="span"
          onClick={onAddTrace}
          disabled={inputCol === "" || loading}
          style={{ marginTop: 15, marginLeft: 4, marginRight: 4, width: "100%" }}
        >
          Add New Sweep
        </Button>
        <TraceDialog
          title={traceDialogTitle()}
          traceNumber={traceDialogIndex}
          columns={otherInputs}
          columnDetails={otherInputData}
          initialTraceName={traceDialogName()}
          initialTrace={traceDialogData()}
          open={traceDialogState}
          setOpen={setTraceDialogState}
          onSuccess={onTraceDialogSuccess()}
          onCopy={onTraceCopy()}
        />
        <Paper style={{ marginTop: 10 }}>
          <List
            style={{ width: "100%", overflowY: 'scroll', height: 260, border: "1px" }}
          >
            {
              traces.map((trace, index) => {
                return (
                  <TraceListItem
                    key={"trace_list_" + index}
                    name={trace.name}
                    onClick={onListItemClick(index)}
                    onDelete={handleDelete(index)}
                    loading={loading}
                  />
                )
              })
            }
          </List>
        </Paper>
      </Grid>
      <Grid>
        <div>
          <CSVLink
            data={csvData.data}
            filename={csvData.filename}
            className="hidden"
            ref={csvLink}
            target="_blank" />
        </div>
        <Button
          variant="contained"
          color="primary"
          className={classes.margin}
          onClick={ExportToCSV}
          style={{ background: "linear-gradient(45deg, #00c853 30%, #00e676 90%)", float: "right", marginTop: 6 }}
          disabled={traces.length === 0}
        >
          <DescriptionIcon style={{ marginRight: 12 }} />
          Export Results to Excel
        </Button>
      </Grid>
    </Grid>
  )
}