import React, { useEffect, useCallback, useMemo, useState } from "react"
import Grid from "@material-ui/core/Grid"
import Typography from "@material-ui/core/Typography"
import { useSnackbar } from "notistack"
import Link from '@material-ui/core/Link';
import HelpCircleIcon from "mdi-material-ui/HelpCircleOutline";
import { Tooltip } from "@material-ui/core";

import HotTable from "../../../HoT/StyledHoT"
import { useAppStoreKey } from "../../../../AppStore"
import { getOriColumns } from "../../utils"
import PredictionChart from "../PredictionChart";
import AdditiveForceVisualizer from "../../components/AdditiveForceVisualizer"
import useModelData from '../useModelData'
import { SinglePointPrediction } from "../AichemyPredict"

import VariableSlider from "./VariableSlider";
import CategoricalEntry from "./CategoricalEntry";
import SmilesTextField from "./SmilesTextField"
import CircularProgress from "@material-ui/core/CircularProgress";

// Help text shown in the tooltip next to the "(SHAP)" heading of the
// single-point prediction view. The literal keeps its source-code newlines and
// indentation; it is rendered inside a Typography element.
const ShapDescription = `
  For a given prediction, SHAP will try and estimate the contribution each input variable makes on the prediction 
  process (computed as offsets from the model's mean response or "base value"). The SHAP value for each input below, 
  indicated by the number associated with a variable name, denotes the impact of that variable in arriving at the 
  model's predicted result shown in bold. Larger SHAP values have more effect and are observed as larger bars, with 
  direction shown in both the color and the side of the predicted result that it pushes on.
  `

// External reference opened (in a new tab) from the "(SHAP)" link.
const ShapLink = "https://christophm.github.io/interpretable-ml-book/shapley.html"

export default function SinglePointPredictionView({ classes, cellStyle, workflow_id, modelID }) {
  const [workflow,] = useAppStoreKey("Workflow");
  const { GetModelData, OriginalColumnOrdering } = useModelData({ modelID: modelID });
  const [stashedData, setStashedData] = useAppStoreKey("MS/Predict/Single")
  const [newData, setNewData] = useState(undefined); //
  const [predData, setPredData] = useState(undefined); // single prediction result
  const [singleShap, setSingleShap] = useState(undefined); // single prediction shap result
  const [hiddenColumns, setHiddenColumns] = useState({}); // which columns should be hidden
  const [isLoading, setIsLoading] = useState(false);
  const { enqueueSnackbar } = useSnackbar();


  const initDataDict = useCallback((predictionData) => {
    let { data } = GetModelData();
    if (!data) return

    let inputColNames = [...OriginalColumnOrdering.input]
    let outputColNames = [...OriginalColumnOrdering.output]
    let newDataDict = {}
    let predDataDict = {}
    let singleShap = {}
    let hidden = {}
    let unknownPred = false

    inputColNames.forEach(item => {
      newDataDict[item] = data.data_df[data.active_sheet][0][item]
      if (newDataDict[item] === undefined) {
        newDataDict[item] = data.info[data.active_sheet].ori_data_range[item][0]
        unknownPred = true
      }
    })
    outputColNames.forEach(item => {
      if (unknownPred) {
        predDataDict[item] = undefined;
      } else {
        predDataDict[item] = data.data_df[data.active_sheet][0][item];
      }
    })
    if (predictionData !== undefined) {
      if (predictionData?.input !== undefined) {
        newDataDict = { ...predictionData.input }
        predDataDict = { ...predictionData.output }
      }
      if (predictionData?.hidden !== undefined) {
        hidden = { ...predictionData.hidden }
      }
    }
    return {
      newData: newDataDict,
      predData: predDataDict,
      hidden: hidden,
      singleShap: singleShap,
    }
  }, [GetModelData, OriginalColumnOrdering])

  // handle loading stashed state if it exists
  useEffect(() => {
    if (newData === undefined || predData === undefined) {
      let data = initDataDict(stashedData)
      setNewData(data.newData)
      setPredData(data.predData)
      setHiddenColumns(data.hidden)
      setSingleShap(data.singleShap)
      setStashedData({
        input: { ...data.newData },
        output: { ...data.predData },
        singleShap: { ...data.singleShap },
        hidden: { ...data.hidden }
      })
    } else if (stashedData === undefined) {
      setStashedData({
        input: undefined,
        output: undefined,
        shap: undefined,
        hidden: undefined
      })
    }
  }, [initDataDict, stashedData, setStashedData, predData, newData])

  const getAllHeaders = useCallback(() => {
    let input = [...OriginalColumnOrdering.input]
    let output = [...OriginalColumnOrdering.output]
    return input.concat(output)
  }, [OriginalColumnOrdering])

  //#region loading data
  const inputData = useMemo(() => {
    let { data, selected_sheet } = GetModelData();
    if (!data) return []

    let ori_data_range = data.info[selected_sheet].ori_data_range
    const inputColNames = OriginalColumnOrdering.input
    let inputData = []
    inputColNames.forEach((item) => {
      let currentType = data.info[selected_sheet].column_types[item] ?
        data.info[selected_sheet].column_types[item] :
        typeof (ori_data_range[item][0])
      if (currentType === 'number') {
        inputData.push(
          {
            name: item,
            type: currentType,
            range: ori_data_range[item],
            selections: ''
          })
      }
      else if (currentType === 'string' || currentType === 'SMILES') {
        inputData.push(
          {
            name: item,
            type: currentType,
            range: ori_data_range[item],
            selections: ori_data_range[item].filter((v, i, a) => a.indexOf(v) === i)
          })
      }
    })
    return inputData
  }, [GetModelData, OriginalColumnOrdering])

  //#endregion loading data
  //#region single point prediction
  const onInputValueChange = (colName, value) => {
    let newValue = { ...newData };
    if (value) {
      if (typeof (value) === 'object') {
        newValue[colName] = value.props.value; // string menu?!
      } else {
        newValue[colName] = value; // slider
      }
    } else {
      newValue[colName] = Number(value); // textbox input
    }
    const inputDataCols = inputData.map(item => item.name)
    let currentNewValue = {}
    Object.keys(newValue).forEach(key => {
      if (inputDataCols.includes(key)) {
        currentNewValue[key] = newValue[key]
      }
    })
    setNewData(newValue)
    runSinglePrediction(currentNewValue, setIsLoading)
  }

  const handleTextChange = (colName, value) => {
    let newValue = { ...newData }
    newValue[colName] = value
    setNewData(newValue)
    runSinglePrediction(newValue)
  }

  function getSinglePredHotData() {
    if (newData === undefined) {
      return [[]]
    }
    let allHeaders = getAllHeaders()
    // fill new data with avg range value
    let filledNewData = {}
    Object.keys(newData).forEach((key) => {
      if (newData[key] === undefined) {
        let range = inputData.find(item => item.name === key).range
        filledNewData[key] = range.reduce((a, b) => a + b) / 2
      } else {
        filledNewData[key] = newData[key]
      }
    })

    let hotData = allHeaders.map((item) => (filledNewData && item in filledNewData) ? filledNewData[item] : predData[item])
    return [hotData]
  }

  const runSinglePrediction = (currentData) => {
    setIsLoading(true)
    const onPredictionSuccess = (response) => {
      let outputColumnNames = [...OriginalColumnOrdering.output]
      let newPredData = { ...predData }
      outputColumnNames.forEach((item, idx) => newPredData[item] = response['result'][0][idx])
      setPredData(newPredData)
      setSingleShap(response['shap'])
      // stash the result
      let savedData = { ...stashedData }
      savedData.input = { ...currentData }
      savedData.output = { ...newPredData }
      savedData.singleShap = response['shap']
      setStashedData(savedData)
      setIsLoading(false)
    }
    const onPredictionFailure = (err) => {
      enqueueSnackbar('Prediction Error (see console for details)', { variant: "error" })
      setIsLoading(false)
      console.error(err)
    }

    // fill empty prediction data with avg data range
    Object.keys(currentData).forEach((item) => {
      if (currentData[item] === undefined) {
        let range = inputData.find(d => d.name === item).range
        currentData[item] = range.reduce((a, b) => a + b) / 2
      }
    })

    SinglePointPrediction(
      workflow_id,
      currentData,
      modelID,
      onPredictionSuccess,
      onPredictionFailure
    )
  }
  //#endregion single point prediction

  //#region column hiding
  const onHideColumn = (column) => {
    let hidden = { ...hiddenColumns };
    hidden[column] = true;
    setHiddenColumns(hidden)
    // stash the result
    let savedData = { ...stashedData }
    savedData.hidden = { ...hidden }
    setStashedData(savedData)
  }

  const onUnhideColumn = (column) => {
    let hidden = { ...hiddenColumns };
    hidden[column] = false
    setHiddenColumns(hidden)
    // stash the result
    let savedData = { ...stashedData }
    savedData.hidden = { ...hidden }
    setStashedData(savedData)
  }
  //#endregion column hiding

  //#region set shap
  const validShap = () => {
    return singleShap !== undefined && Object.keys(singleShap).length > 0 && singleShap.values !== undefined && singleShap.values.length > 0
  }

  // prepare the data for SHAP force plot
  const getSingleShap = () => {
    let result
    const activeSheet = workflow.models[modelID].data[0].active_sheet
    let allOutputCols = workflow.models[modelID].data[0].info[activeSheet].output_cols
    let allInputCols = workflow.models[modelID].data[0].info[activeSheet].input_cols
    // get the original columns
    const selected_sheet = workflow.models[modelID].data[0].active_sheet
    const transforms = workflow.models[modelID].data[0].info[selected_sheet].columns_transform;
    allOutputCols = getOriColumns(transforms, allOutputCols)
    allInputCols = getOriColumns(transforms, allInputCols)
    const n_input = allInputCols.length
    result = allOutputCols.map((col, outIdx) => {
      return {
        featureNames: allInputCols,
        features: Object.assign({}, ...allInputCols.map((item, inIdx) => {
          let value
          // I don't know why, but the values are not always in the same shape...
          if (singleShap.values[0].length === n_input) {
            value = singleShap.values[0][inIdx][outIdx]
          } else {
            value = singleShap.values[0][outIdx][inIdx]
          }
          // if value is too small, SHAP force plot doesn't render correctly
          if (Math.abs(value) <= 1e-8) value = 0
          return { [inIdx]: { value: value, effect: value } }
        })),
        baseValue: singleShap.base_values[0][outIdx],
        link: 'identity',
        outNames: [col]
      }
    })
    return result
  }
  //#endregion set shap

  // on initial load, render nothing until the useEffect is called
  if (newData === undefined) {
    return <></>
  }

  return (<>
    <Grid container justifyContent="flex-start" alignItems="center" spacing={5}>
      <Grid item xs={6}>
        <Typography variant="h6" gutterBottom style={{ textAlign: "left", marginTop: 24 }}>
          Single Point Model Prediction:
        </Typography>

      </Grid>
    </Grid>

    {/* Generate chart of prediction */}
    <Grid container>
      <Grid item xs={12}>
        <PredictionChart inputData={newData} outputData={predData} hidden={hiddenColumns} modelID={modelID} />
      </Grid>
    </Grid>


    {/* Plot each input variable */}
    <Grid container justifyContent="flex-start" alignItems="flex-start" spacing={5}>
      <Grid item xs={12}>
        <Typography variant="subtitle1" gutterBottom style={{ textAlign: "left", marginLeft: 24, marginTop: 15 }}>
          Choose input value
        </Typography>
      </Grid>
      {newData && inputData.map(item => {
        if (item.type === 'number') {
          return (
            <VariableSlider
              key={"var_" + item.name}
              xs={4} xm={3} xl={2}
              variableName={item.name}
              minValue={item.range[0]}
              maxValue={item.range[1]}
              initialValue={newData[item.name] !== undefined ? newData[item.name] : 0.5 * (item.range[0] + item.range[1])}
              initialHidden={!!hiddenColumns[item.name]}
              onChangeCommitted={onInputValueChange}
              onVisible={() => { onUnhideColumn(item.name) }}
              onHidden={() => { onHideColumn(item.name) }}
              isLoading={isLoading}
            />
          );
        }
        else if (item.type === "SMILES") {
          return (
            <SmilesTextField
              key={"smiles_" + item.name}
              xs={4} xm={3} xl={2}
              variableName={item.name}
              options={item.selections}
              initialValue={newData[item.name]}
              onChangeCommitted={handleTextChange}
              onVisible={() => { onUnhideColumn(item.name) }}
              onHidden={() => { onHideColumn(item.name) }}
              isLoading={isLoading}
            />
          )
        }
        else {
          return (
            <CategoricalEntry
              key={"cat_ent_" + item.name}
              xs={4} xm={3} xl={2}
              variableName={item.name}
              options={item.selections}
              initialValue={newData[item.name]}
              onChangeCommitted={handleTextChange}
              onVisible={() => { onUnhideColumn(item.name) }}
              onHidden={() => { onHideColumn(item.name) }}
              isLoading={isLoading}
            />
          )
        }
      })}
    </Grid>
    <div style={{ display: 'flex', flexDirection: 'row' }}>
      {/* Plot each output variable */}
      <Typography variant="subtitle1" gutterBottom style={{ textAlign: "left", marginLeft: 24, marginBottom: 15, marginTop: 24 }}>
        The prediction results:
      </Typography>
      {isLoading && <CircularProgress size={24} style={{ marginLeft: 20, marginTop: 24 }} />}
    </div>

    <div className={classes.htRoot}>
      <HotTable
        data={getSinglePredHotData()}
        colHeaders={getAllHeaders()}
        // Make predictions and populate the columns
        rowHeaders={true}
        columnSorting={false}
        height="75"
        stretchH="all"
        minRows="1"
        editable={false}
        settings={{ outsideClickDeselects: false, readOnly: true }}
        style={{ fontSize: "smaller", fontFamily: "Roboto" }}
        id="predict_table1"
        cells={cellStyle()}
      />
    </div>
    {validShap() && <div style={{ display: "flex", flexDirection: "row" }}>
      <Typography variant="subtitle1" gutterBottom style={{ textAlign: "left", marginLeft: 24, marginBottom: 15, marginTop: 24 }}>
        The input contribution analysis
      </Typography>
      <Typography variant="subtitle1" gutterBottom style={{ textAlign: "left", marginLeft: 4, marginBottom: 15, marginTop: 24 }}>
        <Link href={ShapLink} target="_blank" rel="noreferrer noopener">
          (SHAP)
        </Link>
      </Typography>
      <Tooltip title={<Typography variant="subtitle1">{ShapDescription}</Typography>} placement="top" classes={{ tooltip: classes.customWidth }}>
        <HelpCircleIcon style={{ fontSize: 15, marginRight: 12, marginTop: 24 }} />
      </Tooltip>
    </div>}
    {validShap() && getSingleShap().map((shap) => <div>
      <div style={{ display: 'flex', flexDirection: 'row' }}>
        <Typography variant="subtitle1" gutterBottom style={{ textAlign: "left", marginLeft: 24, marginBottom: 15, marginTop: 24 }}>
          {shap.outNames}:
        </Typography>
        {isLoading && <CircularProgress size={24} style={{ marginLeft: 20, marginTop: 24 }} />}
      </div>
      <div style={{ backgroundColor: "white" }}>
        <AdditiveForceVisualizer
          featureNames={shap.featureNames}
          features={shap.features}
          baseValue={shap.baseValue}
          link={shap.link}
          outNames={shap.outNames}
        />
      </div>
    </div>
    )}
  </>)
}
