import { useCallback, useMemo } from "react";
import { useAppStoreKey } from "AppStore";
import { getOriColumns, isDataChanged } from "../utils";
import Plot from "react-plotly.js";
import "./parcoords.css";

// Cap on how many dataset rows are plotted as (grey) background lines in the
// parallel-coordinates chart; rows past this index are not used when building
// a dimension's values.
const MAX_LINES = 200;

/**
 * Translate a category value into its axis position.
 *
 * @param {Array} categories - known category values, in axis order
 * @param {*} val - value to look up (compared with strict equality)
 * @returns {number} zero-based position of `val`, or -1 if it is not a known category
 */
const catEncode = (categories, val) => categories.indexOf(val);

/**
 * Build a categorical parcoords dimension for an input column that the model
 * consumed as several one-hot encoded child columns.
 *
 * Each row maps to the index of the child column whose value is exactly 1. If
 * several children are 1, the last one wins; if none is, the row maps to -1.
 * Only the first MAX_LINES rows are used, the same cap numericDimension
 * applies. Plotly draws each line row-by-row across all dimensions, so every
 * dimension must contribute the same number of background rows. Otherwise the
 * traces appended later would land on different rows in different dimensions.
 *
 * @param {Array<Object>} data - dataset rows keyed by column name
 * @param {string} col - original (pre-encoding) column name, used as the label
 * @param {string[]} children - one-hot child column names; assumed to be in the
 *   same order as `categories` (child i <-> categories[i]) — confirm upstream
 * @param {Array} categories - category values, used as tick labels
 * @returns {Object} parcoords dimension plus `ori_range`, `encode` and `column_type`
 */
const catInputDimension = (data, col, children, categories) => {
  // Position of the child flagged with 1 in `row`; scanning from the end keeps
  // the original "last match wins" behaviour. -1 when no child is set.
  const whichCat = (row) => {
    for (let index = children.length - 1; index >= 0; index -= 1) {
      if (row[children[index]] === 1) {
        return index;
      }
    }
    return -1;
  };

  const nCat = categories.length;
  return {
    label: col,
    range: [0, nCat - 1],
    ori_range: [0, nCat - 1],
    values: data.slice(0, MAX_LINES).map((row) => whichCat(row)),
    encode: (val) => catEncode(categories, val),
    column_type: 'categorical',
    tickvals: Array.from({ length: nCat }, (_, i) => i),
    ticktext: categories,
  };
};

/**
 * Build a categorical parcoords dimension for an output column whose cells
 * hold the category value itself.
 *
 * Each row is encoded as the position of `row[col]` in `categories`, or -1
 * when the value is not a known category. Only the first MAX_LINES rows are
 * used, the same cap numericDimension applies. Every dimension must contribute
 * the same number of background rows, or the traces appended later would be
 * misaligned across axes.
 *
 * @param {Array<Object>} data - dataset rows keyed by column name
 * @param {string} col - output column name, used as the label
 * @param {Array} categories - category values, used as tick labels
 * @returns {Object} parcoords dimension plus `ori_range`, `encode` and `column_type`
 */
const catOutputDimension = (data, col, categories) => {
  const nCat = categories.length;
  return {
    label: col,
    range: [0, nCat - 1],
    ori_range: [0, nCat - 1],
    values: data.slice(0, MAX_LINES).map((row) => catEncode(categories, row[col])),
    encode: (val) => catEncode(categories, val),
    column_type: 'categorical',
    tickvals: Array.from({ length: nCat }, (_, i) => i),
    ticktext: categories,
  };
};

/**
 * Build a numeric parcoords dimension from column `col`, using at most the
 * first MAX_LINES rows of `data`.
 *
 * @param {Array<Object>} data - dataset rows keyed by column name
 * @param {string} col - column name, used as the axis label
 * @param {number[]} range - [min, max] of the column. It is copied twice so
 *   that later widening of `dim.range` touches neither `ori_range` nor the
 *   caller's array.
 * @returns {Object} parcoords dimension plus `ori_range` and `column_type`
 */
const numericDimension = (data, col, range) => ({
  label: col,
  range: [...range],
  ori_range: [...range],
  // slice takes the prefix directly instead of testing every row's index
  values: data.slice(0, MAX_LINES).map((row) => row[col]),
  column_type: 'numeric',
});

export default function PredictionChart({ inputData, outputData, hidden, modelID, selectedRow }) {
  const [workflow,] = useAppStoreKey("Workflow");

  let backgroundData = useCallback(() => {
    let sheet = workflow.models[modelID].data[0].active_sheet
    let data = workflow.models[modelID].data[0].data_df[sheet];
    let info = workflow.models[modelID].data[0].info[sheet];

    let input_cols = getOriColumns(info.columns_transform, info.input_cols)
    // build the dimensions
    let dims = []
    input_cols.forEach(col => {
      if (info.columns_transform[col].exclude_from_model) {
        // handle legacy models
        if (info.column_types === undefined || info.column_types === null) {
          let is_cat = false
          let children = info.columns_transform[col].children
          children.forEach((child) => {
            if (info.columns_transform[child]?.transformed_by?.includes("CategoricalEncoder")) {
              is_cat = true;
            }
          })
          if (is_cat) {
            // catInputDimension casts across one-hot encoded columns to find cat
            dims.push(catInputDimension(data, col, children, info.ori_data_range[col]))
          } else {
            dims.push(numericDimension(data, col, info.ori_data_range[col]))
          }
          return
        }
        // new models
        if (info.column_types[col] === "SMILES") {
          // there's NO good way right now to show these
        } else if (info.column_types[col] === "string") {
          dims.push(catInputDimension(data, col, info.columns_transform[col].children, info.ori_data_range[col]))
        } else if (info.column_types[col] === "number") {
          dims.push(numericDimension(data, col, info.ori_data_range[col]))
        }
      } else {
        dims.push(numericDimension(data, col, info.ori_data_range[col]))
      }
    })

    info.output_cols.forEach(col => {
      if (info.ori_data_range[col].some(i => typeof (i) === "string")) {
        // catOutputDimension finds index of value in the data range
        dims.push(catOutputDimension(data, col, info.ori_data_range[col]))
      } else {
        dims.push(numericDimension(data, col, info.ori_data_range[col]))
      }
    })

    if (hidden !== undefined) {
      dims = dims.filter(dim => hidden[dim.label] !== true)
    }

    // if data is changed, do not show the values in dims
    if (isDataChanged(workflow)) {
      dims.forEach(dim => {
        dim.values = []
      })
    }

    let colors = new Array(dims[0].values.length).fill(0) // background fill
    return {
      type: 'parcoords',
      dimensions: dims,
      domain: {
        x: [0.025, 0.975], // pad the sides for axes
      },
      line: {
        color: colors,
        colorscale: [
          [0, '#EEE'], // background
          [0.5, '#ff1744'], // selected
          [1, '#2196f3'], // trace
        ],
      },
      labelangle: -25,
      labelside: "top",
      labelfont: {
        size: 10,
      },

    }
  }, [workflow, hidden, modelID])


  const appendToDim = (dim, data) => {
    let result;
    if (dim.column_type === 'numeric') {
      if (Array.isArray(data)) {
        result = data
      } else {
        result = [data]
      }
    } else {
      if (Array.isArray(data)) {
        result = data.map(val => dim.encode(val))
      } else {
        result = [dim.encode(data)]
      }
    }
    dim.values.push(...result)
  }

  const traceData = useMemo(() => {
    if (inputData === undefined) {
      return [{}]
    }
    // append the new trace to the existing background data
    let data = backgroundData();
    // setup the coloration for traces
    let numTrace = 1
    if (Array.isArray(inputData[data.dimensions[0].label])) {
      numTrace = inputData[data.dimensions[0].label].length
    }
    data.line.color.push(...(new Array(numTrace).fill(1)))

    if (selectedRow !== undefined) {
      data.line.color[selectedRow + data.line.color.length - numTrace] = 0.5
    }

    // push data ontop of the dimensions
    data.dimensions.forEach((dim) => {
      if (inputData[dim.label] !== undefined) {
        appendToDim(dim, inputData[dim.label])
      } else {
        appendToDim(dim, outputData[dim.label])
      }
    })

    // recompute dimension bounds
    data.dimensions.forEach((dim) => {
      if (dim.column_type !== 'numeric') {
        return;
      }
      dim.range[0] = Math.min(...dim.values.concat(dim.ori_range[0]).filter(v => v !== undefined))
      dim.range[1] = Math.max(...dim.values.concat(dim.ori_range[1]).filter(v => v !== undefined))
    })
    return [data]
  }, [backgroundData, inputData, outputData, selectedRow])

  return (
    <>
      <Plot
        data={traceData}
        layout={{
          autosize: true,
          margin: {
            autoexpand: true,
            b: 40, // expand toward the bototm of the panel
          }
        }}
        config={{ displaylogo: false }}
        useResizeHandler={true}
        style={{ width: "100%", height: "100%" }}
      />
    </>

  )
}
