import React, { Component } from "react";
import PropTypes from "prop-types";
import sizeMe from "react-sizeme";
import { Alert } from "react-bootstrap";
import { extent } from "d3-array";
import { scaleBand, scaleLinear, scaleOrdinal } from "d3-scale";
import { schemeCategory10 } from "d3-scale-chromatic";

import { getNiceDomain } from "components/Brushing";
import { primary } from "lib/colors";

import styles from "table/TableScatterMatrix.css";

class Scatter extends Component {
  static propTypes = {
    points: PropTypes.array.isRequired,
    width: PropTypes.number.isRequired,
    height: PropTypes.number.isRequired,
    pearsonCorrelationCoefficient: PropTypes.number,
    colorScale: PropTypes.func.isRequired
  };

  render() {
    const {
      points,
      xDomain,
      yDomain,
      width,
      height,
      colorScale,
      pearsonCorrelationCoefficient
    } = this.props;

    // Scales
    const padding = 1;
    const xScale = scaleLinear()
      .domain(xDomain)
      .range([padding, width - padding]);
    const yScale = scaleLinear()
      .domain(yDomain)
      .range([height - padding, padding]);

    // Presentation variables
    const borderWidth = 1;
    const borderColor = "#000";
    const borderOpacity = 0.25;
    const pointRadius = 2;
    const pointOpacity = 0.7;
    const infoSize = 12;
    const infoOpacity = 0.7;

    // Positioning
    const infoX = width - 2;
    const infoY = height - 3;

    // Elements
    const borderElement = (
      <rect
        x={0}
        y={0}
        width={width}
        height={height}
        fill="none"
        stroke={borderColor}
        strokeWidth={borderWidth}
        opacity={borderOpacity}
      />
    );
    const pointsElement = (
      <g data-testid="scatter-matrix-scatter-points">
        {points.map((point, idx) => {
          const x = xScale(point.x);
          const y = yScale(point.y);
          const fill = colorScale(point.label);
          return (
            <circle
              key={String(idx)}
              cx={x}
              cy={y}
              r={pointRadius}
              fill={fill}
              opacity={pointOpacity}
              data-testid="scatter-matrix-scatter-point"
            />
          );
        })}
      </g>
    );
    const coefficientElement =
      typeof pearsonCorrelationCoefficient === "number" ? (
        <text
          x={infoX}
          y={infoY}
          textAnchor="end"
          fontSize={infoSize}
          opacity={infoOpacity}
          data-testid="scatter-matrix-scatter-pearson-correlation-coefficient"
        >{`r = ${pearsonCorrelationCoefficient.toFixed(5)}`}</text>
      ) : null;
    const totalDataPointsElement = (
      <text
        x={infoX}
        y={infoY - infoSize}
        textAnchor="end"
        fontSize={infoSize}
        opacity={infoOpacity}
        data-testid="scatter-matrix-scatter-total-data-points"
      >{`n = ${points.length}`}</text>
    );

    // Final scatter element
    return (
      <g data-testid="scatter-matrix-scatter">
        {pointsElement}
        {coefficientElement}
        {totalDataPointsElement}
        {borderElement}
      </g>
    );
  }
}

class Axis extends Component {
  render() {
    const { domain, width } = this.props;

    // Data
    const [domainStart, domainEnd] = domain;

    // Presentation variables
    const axisFontSize = 14;
    const xStart = 0;
    const xEnd = width;
    const y = 13;
    const tickLength = 5;
    const tickColor = "#000";
    const tickOpacity = 0.25;
    const tickOffset = 2;
    const tickWidth = 1;

    // Final axis element
    return (
      <g>
        <line
          x1={xStart}
          x2={xStart}
          y1={0.5}
          y2={tickLength}
          stroke={tickColor}
          strokeWidth={tickWidth}
          opacity={tickOpacity}
        />
        <line
          x1={xEnd}
          x2={xEnd}
          y1={0.5}
          y2={tickLength}
          stroke={tickColor}
          strokeWidth={tickWidth}
          opacity={tickOpacity}
        />
        <text
          fontSize={axisFontSize}
          x={xStart}
          y={y}
          textAnchor="start"
          dx={tickOffset}
        >
          {domainStart}
        </text>
        <text
          fontSize={axisFontSize}
          x={xEnd}
          y={y}
          textAnchor="end"
          dx={-tickOffset}
        >
          {domainEnd}
        </text>
      </g>
    );
  }
}
class AxisLabel extends Component {
  render() {
    const { label, x, y, angle = 0 } = this.props;

    // Presentational variables
    const axisLabelFontSize = 14;
    const axisLabelColor = "#000";

    return (
      <text
        fill={axisLabelColor}
        fontSize={axisLabelFontSize}
        x={x}
        y={y}
        transform={`rotate(${angle} ${x} ${y})`}
      >
        {label}
      </text>
    );
  }
}

class Legend extends Component {
  render() {
    const { pointLabels, colorScale } = this.props;

    // Positioning variables
    const pointRadius = 4;
    const lineHeight = 20;
    const margin = 4;

    return (
      <g>
        {pointLabels.map((label, index) => (
          <g key={label}>
            <circle
              cx={pointRadius}
              cy={index * lineHeight - pointRadius}
              r={pointRadius}
              fill={colorScale(label)}
            />
            <text x={2 * pointRadius + margin} y={index * lineHeight}>
              {label}
            </text>
          </g>
        ))}
      </g>
    );
  }
}

class ScatterMatrix extends Component {
  static propTypes = {
    size: PropTypes.shape({
      width: PropTypes.number.isRequired
    }).isRequired,
    axisLabels: PropTypes.arrayOf(PropTypes.string).isRequired,
    pointLabels: PropTypes.arrayOf(PropTypes.string).isRequired
  };

  renderScatterWithAxis(xAxisName, yAxisName, x, y, scatterSize, colorScale) {
    const { axisLabels, scatters } = this.props;
    const id = `scatter-${xAxisName}-${yAxisName}`;
    const scatter = scatters.find(
      s => s.xAxisName === xAxisName && s.yAxisName === yAxisName
    );
    const { points, pearsonCorrelationCoefficient } = scatter;

    const xDomain = getNiceDomain(extent(points, p => p.x));
    const yDomain = getNiceDomain(extent(points, p => p.y));
    const showXAxis = axisLabels.indexOf(yAxisName) === axisLabels.length - 1;
    const showYAxis = axisLabels.indexOf(xAxisName) === axisLabels.length - 1;

    const xAxis = (
      <g
        transform={`translate(0 ${scatterSize})`}
        data-testid="scatter-matrix-x-axis"
      >
        <Axis domain={xDomain} width={scatterSize} />
      </g>
    );
    const yAxis = (
      <g
        transform={`rotate(-90) translate(${-scatterSize} ${scatterSize})`}
        data-testid="scatter-matrix-y-axis"
      >
        <Axis domain={yDomain} width={scatterSize} />
      </g>
    );
    return (
      <g key={id} transform={`translate(${x} ${y})`}>
        <Scatter
          width={scatterSize}
          height={scatterSize}
          points={points}
          xDomain={xDomain}
          yDomain={yDomain}
          pearsonCorrelationCoefficient={pearsonCorrelationCoefficient}
          colorScale={colorScale}
        />
        {showXAxis ? xAxis : null}
        {showYAxis ? yAxis : null}
      </g>
    );
  }

  render() {
    const { axisLabels, pointLabels } = this.props;
    const availableWidth = this.props.size.width;

    // Colors
    const categorical = schemeCategory10;
    const hasPointLabels = pointLabels.length > 0;
    const tooManyPointLabels = pointLabels.length > categorical.length;
    const renderingPointLabels = hasPointLabels && !tooManyPointLabels;
    const uniformColorScale = () => primary;
    const colorScale = renderingPointLabels
      ? scaleOrdinal()
          .domain(pointLabels)
          .range(categorical)
          .unknown("#bbb")
      : uniformColorScale;

    // Sizing variables
    const minScatterSize = 100;
    const maxScatterSize = 200;
    const minMatrixSize = axisLabels.length * minScatterSize;
    const maxMatrixSize = axisLabels.length * maxScatterSize;
    const width = Math.max(
      minMatrixSize,
      Math.min(maxMatrixSize, availableWidth)
    );
    const legendMargin = renderingPointLabels ? 30 : 0;
    const legendHeight = renderingPointLabels ? pointLabels.length * 20 : 0;
    const height = width + legendMargin + legendHeight;
    const labelHeight = 30;
    const axisHeight = 15;
    const matrixSize = width - axisHeight - labelHeight;

    // Positioning variables
    const matrixXOffset = labelHeight;
    const matrixYOffset = 1;

    const scatterPosition = scaleBand()
      .domain(axisLabels)
      .range([0, matrixSize])
      .paddingInner(0.1)
      .round(true);
    const scatterSize = scatterPosition.bandwidth();
    const scatterPositionX = axisName =>
      matrixXOffset + scatterPosition(axisName);
    const scatterPositionY = axisName =>
      matrixYOffset + scatterPosition(axisName);
    const xLabelsOffset = matrixYOffset + matrixSize + axisHeight;
    const axisLabelOffset = 15;
    const legendYOffset = xLabelsOffset + labelHeight + legendMargin;
    const legendXOffset = matrixXOffset;

    return (
      <div className={styles.plotRoot}>
        {tooManyPointLabels && (
          <Alert bsStyle="warning">
            There's too many labels to color the plot ({pointLabels.length})
          </Alert>
        )}
        <svg width={width} height={height} className={styles.plotRootSvg}>
          <g data-testid="scatter-matrix-x-axis-labels">
            {axisLabels.map(axisName => (
              <AxisLabel
                key={axisName}
                label={axisName}
                x={scatterPositionX(axisName)}
                y={axisLabelOffset + xLabelsOffset}
              />
            ))}
          </g>
          <g data-testid="scatter-matrix-y-axis-labels">
            {axisLabels.map(axisName => (
              <AxisLabel
                key={axisName}
                label={axisName}
                x={axisLabelOffset}
                y={scatterPositionY(axisName) + scatterSize}
                angle={-90}
              />
            ))}
          </g>
          <g>
            {axisLabels.map(xAxisName =>
              axisLabels.map(yAxisName => {
                const x = scatterPositionX(xAxisName);
                const y = scatterPositionY(yAxisName);
                return this.renderScatterWithAxis(
                  xAxisName,
                  yAxisName,
                  x,
                  y,
                  scatterSize,
                  colorScale
                );
              })
            )}
          </g>
          {renderingPointLabels && (
            <g
              transform={`translate(${legendXOffset} ${legendYOffset})`}
              data-testid="scatter-matrix-point-labels-legend"
            >
              <Legend pointLabels={pointLabels} colorScale={colorScale} />
            </g>
          )}
        </svg>
      </div>
    );
  }
}

const staticSizeMe = size => Component => props => (
  <Component {...props} size={size} />
);

// Under NODE_ENV "test", skip real size measurement and inject a fixed 700px
// width. Everywhere else, use react-sizeme with refreshRate 100.
const isTestEnvironment = process.env.NODE_ENV === "test";
const sized = isTestEnvironment
  ? staticSizeMe({ width: 700 })
  : sizeMe({ refreshRate: 100 });

export default sized(ScatterMatrix);
