import React, { Component } from "react";
import PropTypes from "prop-types";
import sizeMe from "react-sizeme";
import { stratify, cluster } from "d3-hierarchy";
import { scaleLinear, scaleSequential } from "d3-scale";
import { interpolateRdBu } from "d3-scale-chromatic";
import { max, range } from "d3-array";
import { aperture, difference, flatten, equals, last } from "ramda";

import measureTextWidth from "lib/measureTextWidth";

import styles from "table/TableHeatmap.css";

/**
 * Build an SVG `translate(x y)` transform string.
 *
 * @param {number} [x=0] - horizontal offset.
 * @param {number} [y=0] - vertical offset.
 * @returns {string} e.g. `translate(10 0)`.
 */
function translate(x = 0, y = 0) {
  const offset = `${x} ${y}`;
  return `translate(${offset})`;
}

// Number of standard deviations shown on either side of the mean by the
// colour scale and its legend.
const sigmaBandwidth = 3;

// Diverging scale for normalized values: +sigmaBandwidth maps to the start
// (red end) of the RdBu interpolator and -sigmaBandwidth to its end (blue).
export const colorScale = scaleSequential(interpolateRdBu).domain([
  +sigmaBandwidth,
  -sigmaBandwidth
]);

// A cluster is a list of numeric element indices along one heatmap axis.
const clusterPropType = PropTypes.arrayOf(PropTypes.number);

/**
 * Turn successive clustering levels into a d3 hierarchy.
 *
 * Every cluster of level i is attached to the first cluster of level i + 1
 * that strictly contains it (more members, and all of the child's members).
 * A cluster without a strict superset in the next level is skipped at that
 * level, and each cluster is attached at most once. The first cluster of the
 * last level becomes the root.
 *
 * NOTE(review): assumes the last level holds one cluster covering every
 * element; otherwise `stratify` will reject nodes whose parent is missing.
 *
 * @param {Array<{clusters: number[][]}>} levels - clustering levels, later
 *   levels being coarser than earlier ones.
 * @returns {Object} root node built by d3 `stratify`; each node's `data` is
 *   `{ cluster, parentCluster }`.
 */
function constructHierarchy(levels) {
  const clusterKey = members => members.join(",");
  const strictlyContains = (parent, child) =>
    parent.length > child.length && difference(child, parent).length === 0;

  const attached = new Set();
  const nodes = [];
  const levelPairs = aperture(2, levels.map(level => level.clusters));

  for (const [childLevel, parentLevel] of levelPairs) {
    for (const child of childLevel) {
      const key = clusterKey(child);
      if (attached.has(key)) continue;

      const parent = parentLevel.find(candidate =>
        strictlyContains(candidate, child)
      );
      if (!parent) continue;

      attached.add(key);
      nodes.push({ cluster: child, parentCluster: parent });
    }
  }

  const [rootCluster] = last(levels).clusters;
  nodes.push({ cluster: rootCluster, parentCluster: null });

  return stratify()
    .id(node => clusterKey(node.cluster))
    .parentId(node =>
      node.parentCluster ? clusterKey(node.parentCluster) : null
    )(nodes);
}

// Props shared by the clustering decorations (KMeansClusters, Dendrogram).
const clusteringComponentPropTypes = {
  // Successive clustering levels; each level lists clusters of element
  // indices, and a later level's clusters are unions of earlier ones.
  levels: PropTypes.arrayOf(
    PropTypes.shape({
      clusters: PropTypes.arrayOf(clusterPropType).isRequired
    })
  ).isRequired,
  // Size of a single cell along the axis, i.e. the minimal amount of space
  // between neighbouring cluster elements.
  cellSize: PropTypes.number.isRequired,
  // Total span is the amount of space between root and leaves of the
  // dendrogram
  totalSpan: PropTypes.number.isRequired,
  // Axis is a vector containing the numbers 0 to N - 1 representing the order
  // of cells on the axis. E.g. [2, 0, 1] means cell number 2 is rendered
  // first, then cell 0, and finally cell number 1. Both clustering components
  // read `axis.length` / `axis.indexOf` unconditionally, so it is required.
  axis: PropTypes.arrayOf(PropTypes.number.isRequired).isRequired
};

class KMeansClusters extends Component {
  static propTypes = clusteringComponentPropTypes;

  static defaultSpan = 10;

  render() {
    const { levels, axis, cellSize } = this.props;
    const clusters = levels[0].clusters;

    const halfCellSize = cellSize / 2;
    const totalCellSize = axis.length * cellSize;
    const domain = [Math.min(...axis), Math.max(...axis)];
    const scaleY = scaleLinear()
      .domain(domain)
      .range([halfCellSize, totalCellSize - halfCellSize]);

    const clusterGap = 1;
    const clusterHeight = 10;
    const y = element => scaleY(axis.indexOf(element));
    const yStart = element => y(element) - halfCellSize + clusterGap;
    const yEnd = element => y(element) + halfCellSize - clusterGap;

    const stroke = "#666";
    const fill = "#666";

    console.assert(
      equals(axis, flatten(clusters)),
      "The order of KMeans assignments does not match the Heatmap!"
    );

    return clusters.map((cluster, idx) => {
      if (cluster.length === 0) return null;

      const minElement = Math.min(...cluster);
      const maxElement = Math.max(...cluster);
      const clusterStart = yStart(minElement);
      const clusterEnd = yEnd(maxElement);

      return (
        <g key={idx} data-testid="kmeans-cluster">
          <path
            d={`
              M ${clusterHeight},${clusterStart}
              L 0,${clusterStart + clusterHeight / 2}
              L 0,${clusterEnd - clusterHeight / 2}
              L ${clusterHeight},${clusterEnd}
              Z
            `}
            stroke={stroke}
            fill={fill}
          />
        </g>
      );
    });
  }
}

class Dendrogram extends Component {
  static propTypes = clusteringComponentPropTypes;

  static defaultSpan = 100;

  render() {
    const { levels, cellSize, totalSpan, axis } = this.props;

    const hierarchy = constructHierarchy(levels);

    function getNodeCluster(node) {
      return node.data.cluster;
    }

    // Sort the hierarchy so that the leaves order corresponds to Axis prop
    hierarchy.sort((a, b) => {
      const valA = axis.indexOf(Math.min(...getNodeCluster(a)));
      const valB = axis.indexOf(Math.min(...getNodeCluster(b)));
      return valA - valB;
    });

    // Set separation to 1 so all leaves are evenly spaced next to heatmap cells
    const prepareLayout = cluster().separation(() => 1);
    const layout = prepareLayout(hierarchy);
    const leaves = flatten(hierarchy.leaves().map(getNodeCluster));

    console.assert(
      equals(axis, leaves),
      "The order of Dendrogram leaves does not match the Heatmap!"
    );

    const totalCellSize = axis.length * cellSize;
    const scaleY = scaleLinear().range([0, totalCellSize]);
    const scaleX = scaleLinear().range([1, totalSpan]);

    function nodeX(node) {
      return scaleX(node.y);
    }
    function nodeY(node) {
      return scaleY(node.x);
    }

    const links = layout.links();
    const linkElements = links.map(link => {
      const { source, target } = link;
      const key = `${source.id}-${target.id}`;

      return (
        <path
          key={key}
          stroke="#000"
          fill="transparent"
          data-testid="dendrogram-link"
          d={`
          M ${nodeX(link.source)} ${nodeY(link.source)}
          L ${nodeX(link.source)} ${nodeY(link.target)}
          L ${nodeX(link.target)} ${nodeY(link.target)}
        `}
        />
      );
    });

    return <g>{linkElements}</g>;
  }
}

class Legend extends Component {
  render() {
    const { x, y } = this.props;

    const legendPointsCount = 60;
    const legendBarWidth = 60;
    const legendBarHeight = 20;
    const legendBarMargin = 10;
    const legendPointWidth = legendBarWidth / legendPointsCount;

    const legendPoints = range(
      -sigmaBandwidth,
      sigmaBandwidth,
      (2 * sigmaBandwidth) / legendPointsCount
    );

    const legendPointX = scaleLinear()
      .domain([0, legendPointsCount - 1])
      .range([0, legendBarWidth - legendPointWidth]);

    return (
      <g transform={translate(x, y)} data-testid="heatmap-legend">
        <text
          alignmentBaseline="middle"
          x={legendBarMargin}
          y={legendBarHeight / 2}
          dx={-3}
          textAnchor="end"
          fontSize={12}
        >
          -{sigmaBandwidth}σ
        </text>
        <text
          alignmentBaseline="middle"
          x={legendBarMargin + legendBarWidth}
          y={legendBarHeight / 2}
          dx={3}
          textAnchor="start"
          fontSize={12}
        >
          +{sigmaBandwidth}σ
        </text>
        <g transform={translate(legendBarMargin)}>
          {legendPoints.map((value, idx) => (
            <rect
              key={idx}
              x={legendPointX(idx)}
              height={legendBarHeight}
              width={legendPointWidth}
              fill={colorScale(value)}
            />
          ))}
          <rect
            height={legendBarHeight}
            width={legendBarWidth}
            fill="transparent"
            stroke="#666"
            strokeWidth={0.5}
          />
        </g>
      </g>
    );
  }
}

class Heatmap extends Component {
  static propTypes = {
    size: PropTypes.shape({ width: PropTypes.number }).isRequired,
    data: PropTypes.shape({
      x: PropTypes.arrayOf(PropTypes.string).isRequired,
      y: PropTypes.arrayOf(PropTypes.string).isRequired,
      matrix: PropTypes.array.isRequired,
      normalizedMatrix: PropTypes.array.isRequired, // Mean = 0, Variance = 1
      xAxisClusterLevels: PropTypes.arrayOf(
        PropTypes.shape({
          clusters: PropTypes.arrayOf(clusterPropType).isRequired
        })
      ).isRequired,
      yAxisClusterLevels: PropTypes.arrayOf(
        PropTypes.shape({
          clusters: PropTypes.arrayOf(clusterPropType).isRequired
        })
      ).isRequired
    }).isRequired,
    showOriginalData: PropTypes.bool.isRequired
  };

  render() {
    const { data, size, showOriginalData } = this.props;
    const { normalizedMatrix, matrix } = data;
    const xAxisCells = data.x.length;
    const yAxisCells = data.y.length;

    // Axes ordering
    const yAxisOrder =
      data.yAxisClusterLevels.length > 0
        ? flatten(last(data.yAxisClusterLevels).clusters)
        : range(0, yAxisCells);
    const xAxisOrder =
      data.xAxisClusterLevels.length > 0
        ? flatten(last(data.xAxisClusterLevels).clusters)
        : range(0, xAxisCells);

    // Dendrogram / cluster components
    const shouldClusterYAxis = data.yAxisClusterLevels.length === 1;
    const shouldDendrogramYAxis = data.yAxisClusterLevels.length > 1;
    const shouldClusterXAxis = data.xAxisClusterLevels.length === 1;
    const shouldDendrogramXAxis = data.xAxisClusterLevels.length > 1;

    const NoOp = () => null;
    NoOp.defaultSpan = 0;
    const XAxisDendrogramComponent = shouldDendrogramXAxis ? Dendrogram : NoOp;
    const YAxisDendrogramComponent = shouldDendrogramYAxis ? Dendrogram : NoOp;
    const XAxisClusterComponent = shouldClusterXAxis ? KMeansClusters : NoOp;
    const YAxisClusterComponent = shouldClusterYAxis ? KMeansClusters : NoOp;

    // Layout axes
    // | dendrogram | ticks | clusters | heatmap |
    const xAxisTickAngleDegrees = -20;
    const xAxisTickFontSize = 14;
    const xAxisTickAngleRadians = (xAxisTickAngleDegrees / 180) * Math.PI;
    const xAxisTicksHeight =
      xAxisTickFontSize +
      max(data.x, x => measureTextWidth(x, xAxisTickFontSize)) *
        Math.sin(Math.abs(xAxisTickAngleRadians));
    const xAxisTicksWidth =
      measureTextWidth(last(data.x)) *
      Math.cos(Math.abs(xAxisTickAngleRadians));
    const yAxisTicksWidth = max(data.y, y => measureTextWidth(y));
    const yAxisDendrogramWidth = YAxisDendrogramComponent.defaultSpan * 2;
    const xAxisDendrogramHeight = XAxisDendrogramComponent.defaultSpan;
    const yAxisClustersWidth = YAxisClusterComponent.defaultSpan;
    const xAxisClustersHeight = XAxisClusterComponent.defaultSpan;

    const gutter = 5;
    const yAxisDendrogramOffset = 0;
    const yAxisTicksOffset =
      yAxisDendrogramOffset + yAxisDendrogramWidth + yAxisTicksWidth;
    const yAxisClustersOffset = yAxisTicksOffset + gutter;
    const heatmapXOffsetLeft =
      yAxisClustersOffset + yAxisClustersWidth + gutter;

    const xAxisDendrogramOffset = 0;
    const xAxisTicksOffset =
      xAxisDendrogramOffset + xAxisDendrogramHeight + xAxisTicksHeight;
    const xAxisClustersOffset = xAxisTicksOffset + gutter;
    const heatmapYOffset = xAxisClustersOffset + xAxisClustersHeight + gutter;

    // Layout heatmap
    const minCellHeight = 30;
    const minCellWidth = 50;
    const minWidth = xAxisOrder.length * minCellWidth + heatmapXOffsetLeft;
    const width = Math.max(size.width || 0, minWidth);
    // Add right margin when the x-axis tick would be cropped.
    const heatmapXOffsetRight = Math.max(
      0,
      xAxisTicksWidth - (width - heatmapXOffsetLeft) / xAxisCells
    );
    const heatmapWidth = width - heatmapXOffsetLeft - heatmapXOffsetRight;
    const height = heatmapYOffset + yAxisCells * minCellHeight;
    const heatmapHeight = height - heatmapYOffset;
    const cellWidth = heatmapWidth / xAxisCells;
    const cellHeight = heatmapHeight / yAxisCells;

    // Scales
    const xAxisTickX = scaleLinear()
      .domain([0, xAxisCells - 1])
      .range([0, heatmapWidth - cellWidth]);
    const yAxisTickY = scaleLinear()
      .domain([0, yAxisCells - 1])
      .range([cellHeight / 2, heatmapHeight - cellHeight / 2]);

    const cellX = scaleLinear()
      .domain([0, xAxisCells - 1])
      .range([0, heatmapWidth - cellWidth]);
    const cellY = scaleLinear()
      .domain([0, yAxisCells - 1])
      .range([0, heatmapHeight - cellHeight]);

    return (
      <div className={styles.heatmap}>
        <svg width={width} height={height}>
          <Legend x={10} y={1} />
          <g
            transform={translate(0, heatmapYOffset)}
            data-testid="heatmap-y-axis"
          >
            <g
              transform={translate(yAxisDendrogramOffset)}
              data-testid="heatmap-y-axis-dendrogram"
            >
              <YAxisDendrogramComponent
                levels={data.yAxisClusterLevels}
                axis={yAxisOrder}
                cellSize={cellHeight}
                totalSpan={yAxisDendrogramWidth}
              />
            </g>
            <g
              transform={translate(yAxisTicksOffset)}
              data-testid="heatmap-y-axis-ticks"
            >
              {yAxisOrder.map((yAxisIndex, index) => (
                <text
                  key={yAxisIndex}
                  x={0}
                  y={yAxisTickY(index)}
                  textAnchor="end"
                  alignmentBaseline="middle"
                  fontSize={12}
                >
                  {data.y[yAxisIndex]}
                </text>
              ))}
            </g>
            <g
              transform={translate(yAxisClustersOffset)}
              data-testid="heatmap-y-axis-clustering"
            >
              <YAxisClusterComponent
                levels={data.yAxisClusterLevels}
                axis={yAxisOrder}
                cellSize={cellHeight}
                totalSpan={yAxisClustersWidth}
              />
            </g>
          </g>
          <g
            transform={translate(heatmapXOffsetLeft)}
            data-testid="heatmap-x-axis"
          >
            <g
              transform={`
              ${translate(0, xAxisDendrogramOffset)}
              rotate(-90)
              scale(-1, 1)
            `}
              data-testid="heatmap-x-axis-dendrogram"
            >
              <XAxisDendrogramComponent
                levels={data.xAxisClusterLevels}
                axis={xAxisOrder}
                cellSize={cellWidth}
                totalSpan={xAxisDendrogramHeight}
              />
            </g>
            <g
              transform={translate(0, xAxisTicksOffset)}
              data-testid="heatmap-x-axis-ticks"
            >
              {xAxisOrder.map((xAxisIndex, index) => (
                <text
                  key={xAxisIndex}
                  fontSize={12}
                  dx={15}
                  transform={`
                  ${translate(xAxisTickX(index))}
                  rotate(${xAxisTickAngleDegrees})`}
                >
                  {data.x[xAxisIndex]}
                </text>
              ))}
            </g>
            <g
              transform={`
              ${translate(0, xAxisClustersOffset)}
              rotate(-90)
              scale(-1, 1)
            `}
              data-testid="heatmap-x-axis-clustering"
            >
              <XAxisClusterComponent
                levels={data.xAxisClusterLevels}
                axis={xAxisOrder}
                cellSize={cellWidth}
                totalSpan={xAxisClustersHeight}
              />
            </g>
          </g>
          <g
            transform={translate(heatmapXOffsetLeft, heatmapYOffset)}
            data-testid="heatmap-area"
          >
            {yAxisOrder.map((yAxisIndex, yIndex) => (
              <g key={yAxisIndex}>
                {xAxisOrder.map((xAxisIndex, xIndex) => (
                  <g
                    key={xAxisIndex}
                    transform={translate(cellX(xIndex), cellY(yIndex))}
                  >
                    <rect
                      x={0}
                      y={0}
                      width={cellWidth}
                      height={cellHeight}
                      fill={colorScale(
                        normalizedMatrix[yAxisIndex][xAxisIndex]
                      )}
                    />
                    <text
                      x={cellWidth / 2}
                      y={cellHeight / 2}
                      textAnchor="middle"
                      alignmentBaseline="middle"
                      fontSize={10}
                    >
                      {showOriginalData && matrix[yAxisIndex][xAxisIndex]}
                    </text>
                  </g>
                ))}
              </g>
            ))}
          </g>
        </svg>
      </div>
    );
  }
}

// react-sizeme options: re-measure at most every 100 ms, and skip the
// placeholder render when running under the test environment.
const sizeMeOptions = {
  refreshRate: 100,
  noPlaceholder: process.env.NODE_ENV === "test"
};

export default sizeMe(sizeMeOptions)(Heatmap);
