import React, { useEffect, useRef, useState } from 'react';
import { scaleLinear } from 'd3-scale';
import { axisBottom, axisLeft } from 'd3-axis';
import { pointer, select } from 'd3-selection';
import { line } from 'd3-shape';
import styled from 'styled-components';

import colors from '@/constants/colors';

import { DosingStepRow } from '../types';

/** Props for {@link SteppedDosingGraph}. */
interface SteppedDosingGraphProps {
  /**
   * Dosing steps to plot. Each step contributes a point at its low and at its
   * high temperature when both that temperature and its volume are present.
   */
  steps: DosingStepRow[];
  /**
   * Describes the x-axis quantity. `label` and `unit` form the x-axis title;
   * `unit` is also appended to the hover tooltip.
   */
  source: {
    label?: string;
    unit?: string;
  };
  /**
   * Plot the `lowTempC`/`highTempC` fields instead of `lowTempF`/`highTempF`.
   * Optional; the component defaults it to `false`.
   */
  displayCelsiusTemp?: boolean;
  /** Chart height in pixels. Defaults to 250. */
  height?: number;
}

/** One vertex of the dosing line, in data units (not pixels). */
interface Point {
  /** Temperature, in whichever scale (C or F) the graph is displaying. */
  x: number;
  /** Dose volume in millilitres. */
  y: number;
}

/** Chart width in px used before the container's real width is known. */
const DEFAULT_WIDTH = 600;

/**
 * Root `<svg>` of the chart, carrying the styling shared by its children:
 * - `.axis-grid .tick` gives the grid-line ticks a light gray `color`
 *   (d3-axis draws tick lines with `currentColor`).
 * - `.tick text`, `.axis-label` and `.coord-label` set the font sizes of the
 *   tick labels, the axis titles and the hover coordinate label.
 * - `path` eases changes to every path's `d`, so the data line animates when
 *   the steps change (in browsers that support transitioning `d`).
 */
const ChartSVG = styled.svg`
  .axis-grid .tick {
    color: ${colors.gray['300']};
  }
  .tick text {
    font-size: 11px;
  }

  .axis-label {
    font-size: 14px;
  }

  .coord-label {
    font-size: 12px;
  }

  path {
    transition-property: d;
    transition-duration: 0.15s;
    transition-timing-function: ease-in;
  }
`;

/** Type guard that drops `null`/`undefined` entries, e.g. in `Array#filter`. */
function isDefined<T>(value: T | null | undefined): value is T {
  return value != null;
}

/**
 * Returns the index of the point closest to `x` along the x axis, or `null`
 * when there are no points.
 *
 * Positions at or left of the first point snap to the first point; positions
 * at or right of the last point snap to the last one. Otherwise the first
 * point with `x >= cursor` and its predecessor are compared and the closer
 * one wins (ties go to the predecessor). Because the scan stops at the first
 * such point, when two consecutive points share an x value (a vertical step
 * in the line) the point on the cursor's side of the step is chosen.
 *
 * NOTE(review): assumes `points` is ordered by ascending x; not enforced here.
 */
function findNearestPointIndex(points: Point[], x: number): number | null {
  if (points.length === 0) {
    return null;
  }
  const lastIndex = points.length - 1;
  if (x <= points[0].x) {
    return 0;
  }
  if (x >= points[lastIndex].x) {
    return lastIndex;
  }
  // Here points[0].x < x < points[lastIndex].x, so there are at least two
  // points and the scan can start at index 1.
  let next = 1;
  while (next < lastIndex && points[next].x < x) {
    next += 1;
  }
  const diffWithPrevious = Math.abs(x - points[next - 1].x);
  const diffWithNext = Math.abs(points[next].x - x);
  return diffWithPrevious <= diffWithNext ? next - 1 : next;
}

/**
 * Line chart of dose volume (mL) against temperature for a list of dosing
 * steps. Each step contributes a point at its low and at its high temperature
 * (when both temperature and volume are present) and the points are joined
 * in that order. Hovering highlights the point nearest the cursor along the
 * x axis and shows its coordinates.
 */
const SteppedDosingGraph = ({
  steps,
  source,
  displayCelsiusTemp = false,
  height = 250,
}: SteppedDosingGraphProps) => {
  const containerRef = useRef<HTMLDivElement>(null);
  const svgRef = useRef<SVGSVGElement>(null);
  // Chart width in px; DEFAULT_WIDTH until the container has been measured.
  const [width, setWidth] = useState(DEFAULT_WIDTH);

  const [activePointIndex, setActivePointIndex] = useState<number | null>(
    null,
  );

  // Track the container's content-box width (inside its horizontal padding).
  // Measuring in an effect, rather than reading the ref during render, makes
  // container resizes trigger a re-render.
  useEffect(() => {
    const container = containerRef.current;
    // ResizeObserver may be unavailable outside real browsers (e.g. some test
    // DOMs); the chart then stays at DEFAULT_WIDTH.
    if (!container || typeof ResizeObserver === 'undefined') {
      return undefined;
    }
    const observer = new ResizeObserver((entries) => {
      const measured = Math.floor(entries[0]?.contentRect.width ?? 0);
      // Ignore zero widths (e.g. while hidden) so the chart never collapses.
      if (measured > 0) {
        setWidth(measured);
      }
    });
    observer.observe(container);
    return () => observer.disconnect();
  }, []);

  const margin = { top: 20, right: 30, bottom: 60, left: 60 };

  const getLowTemp = (step: DosingStepRow) =>
    displayCelsiusTemp ? step.lowTempC : step.lowTempF;
  const getHighTemp = (step: DosingStepRow) =>
    displayCelsiusTemp ? step.highTempC : step.highTempF;

  // x domain: every known temperature, widened outward to multiples of 10.
  // Missing temperatures are skipped (not counted as 0°), and with no
  // temperatures at all the domain is [0, 0] instead of [Infinity, -Infinity].
  const temperatures = steps
    .flatMap((step) => [getLowTemp(step), getHighTemp(step)])
    .filter(isDefined);
  const normalizedMinTemperature =
    temperatures.length > 0
      ? Math.floor(Math.min(...temperatures) / 10.0) * 10
      : 0;
  const normalizedMaxTemperature =
    temperatures.length > 0
      ? Math.ceil(Math.max(...temperatures) / 10.0) * 10
      : 0;

  const xScale = scaleLinear()
    .domain([normalizedMinTemperature, normalizedMaxTemperature])
    .range([margin.left, width - margin.right]);

  // y domain: 0 up to the largest volume, rounded up to a multiple of 5 once
  // it exceeds 5. Seeding Math.max with 0 keeps the maximum non-negative and
  // finite for empty input; missing volumes are skipped.
  const volumes = steps
    .flatMap((step) => [step.volumeAtLowTempInMl, step.volumeAtHighTempInMl])
    .filter(isDefined);
  const maxVolume = Math.max(0, ...volumes);
  const normalizedMaxVolume =
    maxVolume <= 5 ? maxVolume : Math.ceil(maxVolume / 5.0) * 5;

  const yScale = scaleLinear()
    .domain([0, normalizedMaxVolume])
    .range([height - margin.bottom, margin.top]);

  const plotWidth = width - (margin.left + margin.right);
  const plotHeight = height - (margin.top + margin.bottom);

  // Callback refs that draw the axes with d3. They are recreated on every
  // render, so React re-invokes them and the axes follow the current scales.
  // React also calls them with `null` on detach, which is ignored.
  const xAxis = (node: SVGGElement | null) => {
    if (node) {
      select(node).call(axisBottom(xScale));
    }
  };
  const yAxis = (node: SVGGElement | null) => {
    if (node) {
      select(node).call(axisLeft(yScale));
    }
  };

  // Grid lines: label-less axes whose ticks span the whole plot area.
  const xAxisGrid = (node: SVGGElement | null) => {
    if (node) {
      select(node).call(
        axisBottom(xScale)
          .tickSize(-plotHeight)
          .tickFormat(() => ''),
      );
    }
  };
  const yAxisGrid = (node: SVGGElement | null) => {
    if (node) {
      select(node).call(
        axisLeft(yScale)
          .tickSize(-plotWidth)
          .tickFormat(() => ''),
      );
    }
  };

  // Low/high point per step, in step order, skipping incomplete coordinates.
  const stepPoints: Point[] = [];
  steps.forEach((step) => {
    const lowTemp = getLowTemp(step);
    if (lowTemp != null && step.volumeAtLowTempInMl != null) {
      stepPoints.push({ x: lowTemp, y: step.volumeAtLowTempInMl });
    }

    const highTemp = getHighTemp(step);
    if (highTemp != null && step.volumeAtHighTempInMl != null) {
      stepPoints.push({ x: highTemp, y: step.volumeAtHighTempInMl });
    }
  });

  const linePath = line<Point>()
    .x((d) => xScale(d.x))
    .y((d) => yScale(d.y))(stepPoints);

  // Only appended when a unit is provided, so a missing unit does not show up
  // as the literal text "undefined" in the tooltip.
  const unitSuffix = source.unit ? ` ${source.unit}` : '';

  const stepPointsTooltips = stepPoints.map((point, index) => {
    const isActive = index === activePointIndex;
    const cx = xScale(point.x);
    const cy = yScale(point.y);

    // Nudge the label inward when the point sits on the left/right plot edge.
    let labelPositionX = cx;
    if (labelPositionX <= margin.left) {
      labelPositionX += 40;
    } else if (labelPositionX >= width - margin.right) {
      labelPositionX -= margin.right;
    }

    // Label goes above the point, or below it when within 25px of the top.
    const labelPositionY = cy - 25 < 0 ? cy + 20 : cy - 15;

    return (
      <g key={index}>
        <text
          fill={colors.gray['600']}
          x={labelPositionX}
          y={labelPositionY}
          textAnchor="middle"
          className="coord-label"
        >
          {isActive ? `${point.y} mL @ ${point.x}${unitSuffix}` : ''}
        </text>
        <circle
          cx={cx}
          cy={cy}
          r={isActive ? 5 : 3}
          fill={colors.cyan['400']}
          strokeWidth={isActive ? 2 : 0}
          stroke={colors.cyan['600']}
          style={{ transition: 'ease-out .1s' }}
        />
      </g>
    );
  });

  const handleMouseMove = (e: React.MouseEvent<HTMLDivElement>) => {
    const svg = svgRef.current;
    if (!svg) {
      return;
    }
    // Resolve the cursor in the SVG's own user space (the space xScale maps
    // into, including viewBox scaling), not relative to the padded <div>.
    const [mouseX] = pointer(e.nativeEvent, svg);
    setActivePointIndex(
      findNearestPointIndex(stepPoints, xScale.invert(mouseX)),
    );
  };

  const handleMouseLeave = () => {
    setActivePointIndex(null);
  };

  return (
    <div
      style={{ height: `${height}px` }}
      className="px-4 flex justify-center items-center"
      ref={containerRef}
      onMouseMove={handleMouseMove}
      onMouseLeave={handleMouseLeave}
    >
      <ChartSVG
        ref={svgRef}
        style={{ overflow: 'hidden' }}
        viewBox={`0 0 ${width} ${height}`}
        width={width}
        height={height}
      >
        {/* vertical and horizontal grid lines */}
        <g
          transform={`translate(0,${height - margin.bottom})`}
          ref={xAxisGrid}
          className="axis-grid x-axis-grid"
        />
        <g
          transform={`translate(${margin.left},0)`}
          ref={yAxisGrid}
          className="axis-grid y-axis-grid"
        />
        {/* X and Y axis */}
        <g
          transform={`translate(0,${height - margin.bottom})`}
          ref={xAxis}
          className="axis x-axis"
        />
        <g
          transform={`translate(${margin.left},0)`}
          ref={yAxis}
          className="axis y-axis"
        />
        {/* X and Y axis labels */}
        <text
          transform="rotate(-90)"
          origin="center"
          x={0 - height / 2 - margin.top}
          y={10}
          dy="1em"
          className="axis-label y-axis-label"
        >
          Volume (mL)
        </text>
        <text
          origin="center"
          x={width / 2 - margin.left}
          y={height - margin.top}
          className="axis-label x-axis-label"
        >
          {source.label} ({source.unit})
        </text>
        {/* line */}
        <path
          strokeWidth={2.5}
          fill="none"
          stroke={colors.cyan['500']}
          d={linePath}
        />
        {stepPointsTooltips}
      </ChartSVG>
    </div>
  );
};

export default SteppedDosingGraph;
