import React, { useCallback, useEffect, useMemo, useRef, useState } from "react";
import Typography from "@material-ui/core/Typography";
import Tooltip from "@material-ui/core/Tooltip";
import Grid from "@material-ui/core/Grid";
import Divider from '@material-ui/core/Divider';
import Plot from "react-plotly.js";
import { useAppStoreDispatchKey, useAppStoreKey } from "../../../AppStore"
import List from "@material-ui/core/List"
import ListItem from "@material-ui/core/ListItem"
import ListItemIcon from "@material-ui/core/ListItemIcon"
import Checkbox from "@material-ui/core/Checkbox"
import ListItemText from "@material-ui/core/ListItemText"
import ListItemSecondaryAction from "@material-ui/core/ListItemSecondaryAction"
import ExpandMoreIcon from '@material-ui/icons/ExpandMore';
import SizeSIcon from "mdi-material-ui/SizeS"
import SizeMIcon from "mdi-material-ui/SizeM"
import SizeLIcon from "mdi-material-ui/SizeL"
import HelpCircleIcon from "mdi-material-ui/HelpCircleOutline";
import ToggleButtonGroup from "@material-ui/lab/ToggleButtonGroup"
import ToggleButton from "@material-ui/lab/ToggleButton"
import { makeStyles } from "@material-ui/core/styles"
import common_styles from "../../../styles/common_styles"
import useWindowDimensions, { aichemyProtoUpdateWorkflow } from "../utils";
import { useSnackbar } from "notistack";
import AddIcon from '@material-ui/icons/Add';
import { aichemyProtoAxios } from "../../../API/mmAxios";
import { AccordionSummary, AccordionDetails, Accordion } from "./Accordions"
import ModelParams from "./ModelParams";

import useTrainModel from "./useTrainModel";
import ButtonWithSpinner from "../../General/ButtonWithSpinner";
import TaskStatusStepper from "../Optimize/TaskStatusStepper";

// Style hook for this component: the shared project styles plus the outer
// layout container and the two accordion-summary text styles.
const useStyles = makeStyles((theme) => {
    const sharedStyles = common_styles(theme)
    const headingFontSize = theme.typography.pxToRem(15)
    const secondaryHeadingFontSize = theme.typography.pxToRem(15)

    // Outer container of the settings area.
    const root = {
        width: '100%',
        marginTop: 36,
        marginBottom: 36,
        display: 'flex',
        flexDirection: 'row',
    }

    // Accordion title; takes a fixed third of the summary row.
    const heading = {
        fontSize: headingFontSize,
        flexBasis: '33.33%',
        flexShrink: 0,
    }

    // Accordion subtitle, rendered in the theme's secondary text color.
    const secondaryHeading = {
        fontSize: secondaryHeadingFontSize,
        color: theme.palette.text.secondary,
    }

    return { ...sharedStyles, root, heading, secondaryHeading }
});

//#region Mapper Definitions
// Human-readable label for each plot-type key found in `model.plot`;
// used as the prefix of every plot's display name.
const plotTypeMapper = {
    "parity": 'Parity Plot',
    "shap": 'Sensitivity Plot',
    "confusion_matrix": 'Confusion Matrix',
}

// Grid column span (out of 12) for each plot-size toggle value:
// small, medium and large.
const plotSizeMapper = {
    s: 4,
    m: 6,
    l: 12,
}
//#endregion


function DisplayModelPlot({ model, id: modelIdx }) {
    const classes = useStyles();
    const [workflow, setWorkflow] = useAppStoreKey('Workflow')
    const { width } = useWindowDimensions();
    const [isInfoLoading, setInfoLoading] = useState(false);
    const [plotSettingExpanded, setPlotSettingExpanded] = useState(true)
    const [modelParametersExpanded, setModelParametersExpanded] = useState(true)
    const [isPolling, setIsPolling] = useState(false);
    const [shapStatus, setShapStatus] = useState("Not Submitted")
    const setCurrentTab = useAppStoreDispatchKey("MS/Training/CurrentTab")
    const [visibleTabs,] = useAppStoreKey("MS/Training/VisibleTabs")
    const { enqueueSnackbar } = useSnackbar();
    const { PollTask, SubmitShap, GetShapTaskByModelId } = useTrainModel({ enqueueSnackbar })
    const [isLoading, setIsLoading] = useState(false);

    const mounted = useRef(true)

    // generate a array with length N, N = number of plots
    const getNPlots = useMemo(() => {
        if (model.plot) {
            let n = []
            Object.keys(model.plot).forEach((plotType) => {
                model.plot[plotType].forEach(() => {
                    n.push(0)
                })
            })
            return n
        } return []
    }, [model.plot])
    const [plotSize, setPlotSize] = useState(getNPlots.map(() => "m"))
    const [checked, setChecked] = useState(getNPlots.map((plot, idx) => idx));

    useEffect(() => () => { mounted.current = false }, [])
    useEffect(() => {
        if (model.plot.shap) {
            setShapStatus("Finished")
        }
    }, [model])

    const onShapSuccess = useCallback((wf) => {
        const oriNPlots = checked.length
        const newNPlots = Object.keys(wf.models[modelIdx].plot.shap).length
        let newChecked = [...checked]
        let newPlotSize = [...plotSize]
        for (let i = oriNPlots; i < oriNPlots + newNPlots; i++) {
            newChecked.push(i)
            newPlotSize.push('m')
        }
        setChecked(newChecked)
        setPlotSize(newPlotSize)
        setIsPolling(false)
    }, [checked, modelIdx, plotSize])

    const onShapError = useCallback(() => {
        setIsPolling(false)
    }, [setIsPolling])

    useEffect(() => {
        if (!isPolling) {
            let task = GetShapTaskByModelId(modelIdx)
            if (task) {
                setIsPolling(true)
                PollTask(task.task_id, mounted, setShapStatus, onShapSuccess, onShapError)
            }
        }
    }, [isPolling, setIsPolling, modelIdx, GetShapTaskByModelId, PollTask, onShapSuccess, onShapError])

    const computeShap = () => {
        setShapStatus("Queued")
        SubmitShap(modelIdx)
    }

    const getOutputColName = useCallback((idx) => {
        const activeSheet = workflow.models[modelIdx].data[0].active_sheet
        return workflow.models[modelIdx].data[0].info[activeSheet].output_cols[idx]
    }, [modelIdx, workflow.models])

    const addExtraLayout = useCallback((plot, idx) => {
        let newPlot = { data: [...plot.data], layout: { ...plot.layout } }
        if (plotSize) {
            newPlot.layout.width = (width - 57 - 200) / (12 / plotSizeMapper[plotSize[idx]])
        } else {
            newPlot.layout.width = (width - 57 - 200) / 2
        }

        return newPlot
    }, [plotSize, width])

    const getAllPlots = useMemo(() => {
        if (model.plot) {
            let allPlots = []
            let n = 0
            Object.keys(model.plot).forEach((plotType) => {
                model.plot[plotType].forEach((plot, idx) => {
                    allPlots.push({
                        name: plotTypeMapper[plotType] + ' - ' + getOutputColName(idx),
                        plot: addExtraLayout(plot, n)
                    })
                    n += 1
                })
            })
            return allPlots
        } return []
    }, [model.plot, addExtraLayout, getOutputColName])


    const handleTogglePlot = (value) => () => {
        const currentIndex = checked.indexOf(value);
        const newChecked = [...checked];

        if (currentIndex === -1) {
            newChecked.push(value);
        } else {
            newChecked.splice(currentIndex, 1);
        }

        setChecked(newChecked);
    };

    let handlePlotSize = (idx) => (event, newValue) => {
        let newPlotSize = [...plotSize]
        newPlotSize[idx] = newValue
        setPlotSize(newPlotSize)
    }

    const handleUpdateModelInfo = ({ title, description }) => {
        const kwargs = {
            model_info: {
                model_idx: modelIdx,
                name: title,
                description: description ? description : ''
            }
        }

        const update_dict = {
            data_cleaning: {
                EditModelInfo: {
                    kwargs: kwargs
                }
            }
        }
        setInfoLoading(true)
        aichemyProtoUpdateWorkflow(update_dict, workflow, setWorkflow).catch(err => {
            enqueueSnackbar("Failed to update model info.", { variant: "error" });
            console.error(err)
        }).finally(() => { setInfoLoading(false) })
    }

    const modelIsCat = () => {
        return model.info.type === 'classification'
    }

    const hideModel = (modelIdx) => () => {
        setIsLoading(true)
        let url = `workflow/` + workflow.uuid + `/model/${modelIdx}`
        aichemyProtoAxios.delete(url, {
            headers: { "Content-Type": "application/json; charset=utf-8" }
        })
            .then((res) => {
                const wf = res.data
                setWorkflow(wf)
                if (visibleTabs === undefined) {
                    setCurrentTab(undefined)
                } else {
                    let currentTabIdx = visibleTabs.indexOf(modelIdx)
                    if (currentTabIdx === 0) {
                        setCurrentTab(visibleTabs[1])  // 0 is removed, so set to 1
                    } else {
                        setCurrentTab(visibleTabs[currentTabIdx - 1])
                    }
                }
                setIsLoading(false)
            })
            .catch(error => {
                enqueueSnackbar("Model Deletion Failed.", { variant: "error" });
                console.error(error)
            });
    }


    return <>
        {model.plot && <>
            <div className={classes.root} style={{ display: 'flex', flexDirection: 'column' }}>
                {/*Remove model*/}
                <div style={{
                    display: 'flex',
                    justifyContent: 'flex-end',
                    marginLeft: 'auto',
                    marginBottom: 15,
                }}>
                    <ButtonWithSpinner
                        variant="contained"
                        color="secondary"
                        spinnerColor="primary"
                        loading={isLoading}
                        onClick={hideModel(modelIdx)}
                    >
                        Delete Model
                    </ButtonWithSpinner>
                </div>
                <div>
                    <Grid container spacing={2}>
                        <Grid item xs={6}>
                            <Accordion
                                expanded={plotSettingExpanded}
                                onChange={() => setPlotSettingExpanded(!plotSettingExpanded)}

                            >
                                <AccordionSummary
                                    expandIcon={<ExpandMoreIcon />}
                                    aria-controls="panel1bh-content"
                                    id="panel1bh-header"
                                >
                                    <Typography component={'span'} className={classes.heading}>Plot settings</Typography>
                                    <Typography component={'span'} className={classes.secondaryHeading}>Select plots to display</Typography>
                                </AccordionSummary>
                                <AccordionDetails>
                                    <List style={{ flexBasis: "100%" }}>
                                        {getAllPlots.map((plot, idx) => (
                                            <ListItem key={`plot_${idx}`} dense button onClick={handleTogglePlot(idx)}>
                                                <ListItemIcon>
                                                    <Checkbox
                                                        edge="start"
                                                        checked={checked.indexOf(idx) !== -1}
                                                        tabIndex={-1}
                                                        disableRipple
                                                        inputProps={{ 'aria-labelledby': plot.name }}
                                                    />
                                                </ListItemIcon>
                                                <ListItemText id={plot.name} primary={plot.name} />
                                                <ListItemSecondaryAction>
                                                    {/*<IconButton edge="end" aria-label="comments">*/}
                                                    <ToggleButtonGroup
                                                        value={plotSize[idx]}
                                                        exclusive
                                                        onChange={handlePlotSize(idx)}
                                                        aria-label="text alignment"
                                                    >
                                                        <ToggleButton value="s" aria-label="left aligned">
                                                            <SizeSIcon />
                                                        </ToggleButton>
                                                        <ToggleButton value="m" aria-label="centered">
                                                            <SizeMIcon />
                                                        </ToggleButton>
                                                        <ToggleButton value="l" aria-label="right aligned">
                                                            <SizeLIcon />
                                                        </ToggleButton>
                                                    </ToggleButtonGroup>
                                                    {/*</IconButton>*/}
                                                </ListItemSecondaryAction>
                                            </ListItem>
                                        ))}
                                        {(shapStatus === "Not Submitted" && !modelIsCat()) && (
                                            <ListItem key="shap" component="div" dense button onClick={computeShap}>
                                                <ListItemIcon>
                                                    <AddIcon />
                                                </ListItemIcon>
                                                <ListItemText
                                                    id="shap"
                                                    color="red"
                                                    primary={
                                                        <>
                                                            Compute Feature Importance (May take a long time)
                                                            <Tooltip title={
                                                                <Typography variant="subtitle1">
                                                                    This process for larger data sets or more complex model structures can take a significant amount of time to compute. This will run in the background/will update when finished so feel free to close the browser or navigate away.
                                                                </Typography>} placement="top">
                                                                <HelpCircleIcon style={{ fontSize: 15, marginRight: 12 }} />
                                                            </Tooltip>
                                                        </>
                                                    }
                                                />
                                            </ListItem>
                                        )}
                                        {(shapStatus !== "Finished" && shapStatus !== "Not Submitted" && !modelIsCat()) && (
                                            <>
                                                <ListItem key="shap" component="div" dense>
                                                    <ListItemText
                                                        id="shap"
                                                        color="red"
                                                        primary={
                                                            <>
                                                                Computing Feature Importances...
                                                                <Tooltip title={
                                                                    <Typography variant="subtitle1">
                                                                        This process for larger data sets or more complex model structures can take a significant amount of time to compute. This will run in the background/will update when finished so feel free to close the browser or navigate away.
                                                                    </Typography>} placement="top">
                                                                    <HelpCircleIcon style={{ fontSize: 15, marginRight: 12 }} />
                                                                </Tooltip>
                                                            </>
                                                        }
                                                    />
                                                </ListItem>
                                                <TaskStatusStepper taskStatus={shapStatus} />
                                            </>
                                        )}

                                    </List>
                                </AccordionDetails>
                            </Accordion>
                        </Grid>
                        <Grid item xs={6}>
                            <ModelParams
                                modelParametersExpanded={modelParametersExpanded}
                                setModelParametersExpanded={setModelParametersExpanded}
                                handleUpdateModelInfo={handleUpdateModelInfo}
                                isInfoLoading={isInfoLoading}
                                model={model}
                            />
                        </Grid>
                    </Grid>
                </div>
            </div>

            <Divider />
            <Grid container justifyContent="flex-start">
                {getAllPlots.map((plot, idx) => {
                    if (checked.indexOf(idx) > -1) {
                        return <Grid key={'grid_' + idx} item xs={plotSizeMapper[plotSize[idx]]}>
                            <Typography
                                // component={'span'}
                                variant="h6"
                                style={{
                                    textAlign: "left",
                                    marginLeft: 24,
                                    marginRight: 24,
                                    marginTop: 24
                                }}
                            >
                                {plot.name}
                            </Typography>
                            <Plot data={plot.plot.data} layout={plot.plot.layout} />
                        </Grid>
                    } else return ''
                })}
            </Grid>
        </>}
    </>
}

// Export the component wrapped in React.memo so re-renders triggered by a
// parent are skipped when the props are shallowly equal.
export default React.memo(DisplayModelPlot);