import React, { Component } from "react";
import PropTypes from "prop-types";
import {
  XAxis,
  YAxis,
  Bar,
  BarChart,
  Tooltip as RechartsTooltip
} from "recharts";
import {
  flatten,
  groupWith,
  equals,
  sort,
  pipe,
  map,
  range,
  head,
  last
} from "ramda";
import { scaleLinear, scaleOrdinal } from "d3-scale";
import { extent } from "d3-array";

import { getNiceDomain } from "components/Brushing";
import ResponsiveContainer from "components/ResponsiveContainer";
import Tooltip, {
  TooltipGroup,
  TooltipTitle,
  TooltipValue
} from "containers/Entities/TaxonomicAnalysis/Biom/Tooltip";

import { categoricalSimple } from "lib/colors";

class DotTooltip extends Component {
  render() {
    const { payload: rawPayload } = this.props;

    if (rawPayload === null || !rawPayload[0]) return null;
    const payload = rawPayload[0].payload;

    const showMean = typeof payload.mean === "number";
    const formattedSize = payload.size;
    const formattedMean = payload.mean;

    return (
      <Tooltip>
        <TooltipTitle>{payload.name}</TooltipTitle>
        <TooltipGroup>
          <TooltipValue>Sample size: {formattedSize}</TooltipValue>
          {showMean && <TooltipValue>Mean: {formattedMean}</TooltipValue>}
        </TooltipGroup>
      </Tooltip>
    );
  }
}

class Points extends Component {
  renderPoints(center) {
    const { scale, colorScale, payload } = this.props;

    const dotRadius = 3;
    const dotMargin = 3;

    function makeScaleForPoints(points) {
      const count = points.length;
      const r = dotRadius;

      // Draw points side-by side if multiple at the same level
      const width = count * 2 * r + (count - 1) * dotMargin;
      const ids = range(0, count);
      const cx = scaleLinear()
        .domain([head(ids), last(ids)])
        .range([center - width / 2 + r, center + width / 2 - r]);

      const cy = scale(head(points));
      return { cx, cy, r, ids };
    }
    const sortAscending = sort((a, b) => a - b);

    const points = pipe(
      sortAscending,
      groupWith(equals),
      map(makeScaleForPoints)
    )(payload.points);

    const fill = colorScale(payload.name);
    const stroke = "#000";

    return (
      <g data-testid="dot-plot-points">
        {points.map(({ cx, cy, r, ids }) => (
          <g key={cy}>
            {ids.map(id => (
              <circle
                key={id}
                cx={cx(id)}
                cy={cy}
                r={r}
                fill={fill}
                stroke={stroke}
                data-testid="dot-plot-point"
              />
            ))}
          </g>
        ))}
      </g>
    );
  }

  renderMean(center) {
    const { width, payload, scale } = this.props;

    const mean = payload.mean;
    const showMean = typeof mean === "number";
    const meanY = scale(mean);
    const stroke = "#000";
    const lineWidth = 2;
    const maxLineLength = 100;
    const lineLength = Math.min(maxLineLength, width);

    return (
      showMean && (
        <line
          x1={center - lineLength / 2}
          y1={meanY}
          x2={center + lineLength / 2}
          y2={meanY}
          stroke={stroke}
          strokeWidth={lineWidth}
          data-testid="dot-plot-mean"
        />
      )
    );
  }

  render() {
    const { x, width } = this.props;
    const center = x + width / 2;

    return (
      <g data-testid="dot-plot-points">
        {this.renderMean(center)}
        {this.renderPoints(center)}
      </g>
    );
  }
}
class DotPlot extends Component {
  static propTypes = {
    series: PropTypes.arrayOf(
      PropTypes.shape({
        name: PropTypes.string.isRequired,
        points: PropTypes.arrayOf(PropTypes.number).isRequired,
        size: PropTypes.number.isRequired,
        mean: PropTypes.number
      })
    ).isRequired
  };

  render() {
    const { series } = this.props;

    const allPoints = flatten(series.map(series => series.points));
    const domain = getNiceDomain(extent(allPoints));
    const scale = scaleLinear().domain(domain);
    const colorScale = scaleOrdinal()
      .domain(series.map(series => series.name))
      .range(categoricalSimple);

    const rechartsData = series;
    const dataKey = "mean";
    const width = "100%";
    const height = 300;
    const margin = { left: 20, top: 5 };
    const yAxisLabel = {
      value: "Value",
      angle: -90,
      position: "left",
      dy: -10
    };
    const tooltipCursor = { opacity: 0.5 };
    const tooltipOffset = 50;

    return (
      <ResponsiveContainer width={width} height={height}>
        <BarChart data={rechartsData} margin={margin}>
          <XAxis dataKey="name" interval={0} />
          <YAxis
            dataKey={dataKey}
            domain={domain}
            scale={scale}
            label={yAxisLabel}
          />
          <RechartsTooltip
            content={<DotTooltip />}
            cursor={tooltipCursor}
            offset={tooltipOffset}
          />
          <Bar
            dataKey={dataKey}
            isAnimationActive={false}
            shape={<Points scale={scale} colorScale={colorScale} />}
          />
        </BarChart>
      </ResponsiveContainer>
    );
  }
}

// The DotPlot chart component is this module's only export.
export default DotPlot;
