import * as d3 from 'd3'
import * as _ from 'lodash'
import BaseChart from "./BaseChart"

// Default options for DensityPlot. The constructor deep-merges these (via _.merge)
// underneath whatever configuration the caller supplied.
// NOTE(review): the d3.scaleLinear() instances below are created once, at module load.
// _.merge copies function values (d3 scales are functions) by reference, so every chart
// that does not supply its own scales appears to share these same scale objects —
// confirm BaseChart never mutates them (e.g. via .domain()/.range()), or create them per instance.
const DEFAULT_CONFIGURATION = {
  margin: { top: 0, right: 0, bottom: 0, left: 0 },
  labelPadding: { x: 0, y: 0 },
  // chartPadding.x is read by getKDE as a fractional widening factor for the data extent.
  chartPadding: { x: 0, y: 0 },
  scale: { 
    // Per-axis scale and tick options; presumably consumed by BaseChart when it
    // builds the axes — not read anywhere in this file.
    x: {
      scale: d3.scaleLinear(),
      ticks: undefined,
      tickSize: undefined,
      tickFormat: undefined,
    }, 
    y: {
      scale: d3.scaleLinear(),
      ticks: undefined,
      tickSize: undefined,
      tickFormat: undefined
    }
  },
  // Not read in this file; presumably toggles axis drawing in BaseChart.
  renderAxis: true,
  // One { stroke, fill } entry per series, matched to series by array position in renderChart.
  styling: [
    {
      stroke: "red",
      fill: "rgba(255, 0, 0, 0.2)"
    },
    {
      stroke: "blue",
      fill: "rgba(0, 0, 255, 0.2)"
    }
  ],
  // true: fit a normal distribution to each series; false: use a kernel density estimate.
  useGaussian: false
}

/**
 * Probability density of the normal distribution N(mu, sigma^2) evaluated at x.
 *
 * @param {number} x - point at which to evaluate the density.
 * @param {number} mu - mean of the distribution.
 * @param {number} sigma - standard deviation of the distribution.
 * @returns {number} the density value f(x).
 */
const normalPDF = (x, mu, sigma) => {
  const deviation = x - mu
  const exponent = -Math.pow(deviation, 2) / (2 * Math.pow(sigma, 2))
  const normaliser = sigma * Math.sqrt(2 * Math.PI)
  return Math.exp(exponent) / normaliser
}

/**
 * Samples the normal density N(mu, sigma^2) on an evenly spaced grid.
 *
 * The grid covers the half-open interval [mu - width*sigma, mu + width*sigma)
 * with exactly `steps` points.
 *
 * @param {number} mu - mean of the distribution.
 * @param {number} sigma - standard deviation; must be positive and finite.
 * @param {number} [steps=200] - number of sample points.
 * @param {number} [width=4] - half-width of the sampled interval, in multiples of sigma.
 * @returns {{x: number, y: number}[]} density points in ascending x order, or []
 *   when mu/sigma cannot describe a distribution (e.g. sigma is 0, negative or undefined).
 */
const createDistribution = (mu, sigma, steps = 200, width = 4) => {
  // Degenerate inputs (such as a deviation that could not be computed) have no curve to draw.
  if (!Number.isFinite(mu) || !Number.isFinite(sigma) || !(sigma > 0)) {
    return []
  }
  const lowerBound = mu - sigma * width
  const upperBound = mu + sigma * width
  const stepSize = (upperBound - lowerBound) / steps
  const distribution = []
  // Derive each x from the integer index instead of accumulating `x += stepSize`.
  // Repeated float addition drifts and can emit an extra point next to the upper bound.
  for (let k = 0; k < steps; k++) {
    const x = lowerBound + k * stepSize
    distribution.push({ x, y: normalPDF(x, mu, sigma) })
  }
  return distribution
}

/**
 * Fits a normal distribution to a sample and returns its sampled density curve.
 *
 * @param {number[]} values - raw samples for one series.
 * @returns {{x: number, y: number}[]} the curve from createDistribution, using the
 *   sample mean (d3.mean) as mu and the sample standard deviation (d3.deviation) as sigma.
 */
const getGaussian = (values) => createDistribution(d3.mean(values), d3.deviation(values))

/**
 * Builds an Epanechnikov kernel with the given bandwidth h:
 *   K_h(u) = 0.75 * (1 - (u/h)^2) / h   when |u/h| <= 1, otherwise 0.
 *
 * @param {number} bandwidth - kernel half-width h, in data units (expected > 0).
 * @returns {(x: number) => number} kernel function of the offset x.
 */
const epanechnikov = (bandwidth) => {
  return (x) => {
    // Scale into a local instead of reassigning the parameter inside the condition.
    const scaled = x / bandwidth
    return Math.abs(scaled) <= 1 ? 0.75 * (1 - scaled * scaled) / bandwidth : 0
  }
}

/**
 * Kernel density estimate of `data`, evaluated at each threshold.
 *
 * @param {(offset: number) => number} kernel - kernel function of (threshold - sample).
 * @param {number[]} thresholds - x positions at which to evaluate the estimate.
 * @param {number[]} data - raw samples.
 * @returns {{x: number, y: number}[]} one point per threshold; y is the mean kernel
 *   value over all samples (computed with d3.mean).
 */
const kde = (kernel, thresholds, data) =>
  thresholds.map((threshold) => ({
    x: threshold,
    y: d3.mean(data, (sample) => kernel(threshold - sample))
  }))

/**
 * Estimates the density of `values` with an Epanechnikov kernel, evaluated at
 * roughly 40 "nice" tick positions spanning the (padded) extent of the data.
 *
 * @param {number[]} values - raw samples for one series.
 * @param {object} chartObject - the chart; this reads `configuration.chartPadding.x`,
 *   the optional `configuration.bandwidth` and `chartDimensions.innerWidth`.
 * @returns {{x: number, y: number}[]} density estimate at each threshold.
 */
const getKDE = (values, chartObject) => {
  const { chartPadding: padding, bandwidth: configuredBandwidth } = chartObject.configuration
  // Bandwidth is in data units. 10 was previously hard-coded, so it remains the fallback.
  // Non-positive or non-finite values would make the kernel divide by zero or flatten to 0.
  const bandwidth =
    Number.isFinite(configuredBandwidth) && configuredBandwidth > 0 ? configuredBandwidth : 10

  const [min, max] = d3.extent(values)
  const delta = Math.abs(max - min)
  // NOTE(review): padding is multiplicative, so for a negative min/max it narrows the
  // domain instead of widening it. The 0.005/0.015 margins presumably mirror the axis
  // domain computed in BaseChart — confirm both before changing either.
  const x = d3.scaleLinear()
    .domain([
      min * (1 - padding.x) - delta * 0.005,
      max * (1 + padding.x) + delta * 0.015
    ])
    .range([0, chartObject.chartDimensions.innerWidth])
    .nice()

  const thresholds = x.ticks(40)
  return kde(epanechnikov(bandwidth), thresholds, values)
}

/**
 * Chart drawing one filled probability-density curve per input series.
 *
 * Each input series carries its raw samples in `data`. The other fields of a series
 * are copied through unchanged. The samples are replaced by `{ x, y }` density points,
 * produced in one of two ways:
 * - a fitted normal distribution, when `configuration.useGaussian` is true;
 * - otherwise, an Epanechnikov kernel density estimate.
 */
class DensityPlot extends BaseChart {
  constructor(containerId, data, configuration, xMapping, yMapping) {
    super(containerId, data, configuration, xMapping, yMapping)
    // Fill in any options that were omitted. NOTE(review): assumes BaseChart stores the
    // caller's options on this.configuration — confirm.
    this.configuration = _.merge({}, DEFAULT_CONFIGURATION, this.configuration)

    // Computed before getAxis(), presumably so the axes are fitted to the density points.
    this.setDistributionData(data)

    const { scales, axes } = this.getAxis()
    this.scales = scales
    this.axes = axes
  }

  /**
   * Draws one filled area per series, replacing areas left by a previous render.
   * Colours come from configuration.styling and are reused cyclically when there
   * are more series than styling entries.
   */
  renderChart() {
    const areaGenerator = d3.area()
      .x(d => this.scales.xScale(this.xMapping(d)))
      .y0(() => this.scales.yScale(0))
      .y1(d => this.scales.yScale(this.yMapping(d)))
      .curve(d3.curveBasis)

    const { styling } = this.configuration
    const g = this.getChartContainer().select(".chart-container")
    // Nothing in this class creates ".distribution-path"; the removal is kept so that
    // stale nodes created elsewhere are still cleaned up.
    g.selectAll(".distribution-path").remove()
    g.selectAll(".distribution-area").remove()
    this.data.forEach((series, idx) => {
      // Cycle through the palette; indexing past its end used to throw on `undefined.fill`.
      const style = styling[idx % styling.length]
      g.append('path')
        .attr('class', 'distribution-area')
        .style("fill", style.fill)
        .style('stroke-width', 1.5)
        .style('stroke', style.stroke)
        .attr('d', areaGenerator(series.data))
    })
  }

  /**
   * Replaces this.data with density curves computed from the raw series.
   *
   * @param {Array<{data: number[]}>} data - raw series. Each `data` array holds the
   *   samples, and every other field is copied through unchanged.
   */
  setDistributionData(data) {
    const toDensity = this.configuration.useGaussian
      ? (values) => getGaussian(values)
      : (values) => getKDE(values, this)
    this.data = data.map((series) => ({ ...series, data: toDensity(series.data) }))
  }

  /**
   * Recomputes the density curves for new raw series and redraws the chart.
   * NOTE(review): this.scales is not recomputed here — confirm that setUpChart (BaseChart)
   * refreshes it, otherwise the new curves are drawn against the old axis scales.
   *
   * @param {Array<{data: number[]}>} newData - raw series, in the same shape as the constructor's `data`.
   */
  updateData(newData) {
    this.setDistributionData(newData)
    this.clearChart()
    this.setUpChart()
    this.renderChart()
  }
}

export default DensityPlot