import {LinePath, Circle} from '@visx/shape';
import {scaleLinear} from '@visx/scale';
import {Grid, useTheme} from '@mui/material';
import {AxisBottom, AxisLeft} from '@visx/axis';
import {Group} from '@visx/group';
import {useEffect, useRef, useState, useCallback, useMemo} from 'react';
import {BasalProfile} from '../../../api/service-types';
import styled from '@emotion/styled';
import {curveStepAfter} from '@visx/curve';
import GridColumns from '@visx/grid/lib/grids/GridColumns';
import GridRows from '@visx/grid/lib/grids/GridRows';

/** Fixed pixel height of the chart's SVG. */
const CHART_HEIGHT = 300;
/** Space (px) between the SVG edges and the plotting area, leaving room for the axes. */
const CHART_MARGIN = {top: 20, right: 20, bottom: 30, left: 40};
/** First and last hour shown on the x axis; profile entry `i` is plotted at hour `i + 1`. */
const FIRST_HOUR = 1;
const LAST_HOUR = 24;
/** At or below this SVG width (px) the x axis labels every other hour instead of every hour. */
const NARROW_WIDTH_BREAKPOINT = 780;
/** Max vertical distance (px) between the cursor and profile I's line for the tooltip to appear. */
const TOOLTIP_HOVER_DISTANCE = 50;

/**
 * Axis-title heading. Defined at module level so React sees the same component type on every
 * render; calling `styled(...)` inside the component created a new type each render and forced
 * the labels to remount.
 */
const ChartLabel = styled('h5')(() => ({
  display: 'flex',
  alignItems: 'center',
  justifyContent: 'center',
  padding: 0,
  margin: 0,
}));

/**
 * Computes the y-axis domain for the plotted values: the outermost "nice" ticks of the data range,
 * each extended by one tick interval so the lines never touch the top or bottom edge of the plot.
 * Bounds are rounded to 2 decimals.
 *
 * Falls back to [0, 1] when there are no finite values, and widens a flat range (all values equal)
 * before computing ticks, since a zero-width range yields a single tick and no tick interval.
 */
function computeYDomain(values: number[]): [number, number] {
  let min = Math.min(...values);
  let max = Math.max(...values);
  // Math.min/max of [] are +/-Infinity; non-numeric entries (Number('x')) produce NaN.
  if (!Number.isFinite(min) || !Number.isFinite(max)) return [0, 1];
  if (min === max) {
    const pad = Math.abs(min) * 0.1 || 1;
    min -= pad;
    max += pad;
  }

  const ticks = scaleLinear({domain: [min, max]}).ticks();
  if (ticks.length < 2) return [min, max];
  const firstTick = ticks[0];
  const lastTick = ticks[ticks.length - 1];
  const tickStep = lastTick - ticks[ticks.length - 2];
  return [parseFloat((firstTick - tickStep).toFixed(2)), parseFloat((lastTick + tickStep).toFixed(2))];
}

/**
 * Responsive chart of up to two basal profiles (values in U/h, plotted at hours 1-24), with a
 * vertical crosshair that follows the cursor and a tooltip showing profile I's value for the hour
 * under the cursor.
 *
 * @param basalProfiles Profiles to plot. Only entries 0 ("profile I", grey) and 1 ("profile II",
 *   red) are drawn; each entry's `profile` values are converted with `Number` and plotted as a
 *   step function at hours 1, 2, 3, ... An empty input renders empty axes instead of crashing.
 *
 * NOTE(review): this is the component's first argument, which React normally passes as a props
 * object. Indexing works when the component is called as a plain function or rendered with an
 * array spread (`<BasalChart {...profiles} />`); confirm against callers before changing it.
 */
const BasalChart = (basalProfiles: BasalProfile[]) => {
  const [width, setWidth] = useState(0); // measured from the SVG so the chart is responsive
  const height = CHART_HEIGHT;
  const margin = CHART_MARGIN;
  const svgRef = useRef<SVGSVGElement>(null);
  const [tooltipData, setTooltipData] = useState<{hour: number; value: number} | null>(null);
  const [tooltipPosition, setTooltipPosition] = useState<{x: number; y: number} | null>(null);
  const [crosshairTime, setCrosshairTime] = useState<number | null>(null);
  const theme = useTheme();

  // Track the rendered SVG width: once on mount, then on every window resize.
  useEffect(() => {
    const measureWidth = () => {
      if (svgRef.current) {
        setWidth(svgRef.current.clientWidth);
      }
    };

    measureWidth();
    window.addEventListener('resize', measureWidth);
    return () => window.removeEventListener('resize', measureWidth);
  }, []);

  const profileI = useMemo(() => basalProfiles[0]?.profile.map(Number) ?? [], [basalProfiles]);
  const profileII = useMemo(() => basalProfiles[1]?.profile.map(Number), [basalProfiles]);

  const yScale = useMemo(
    () =>
      scaleLinear({
        domain: computeYDomain([...profileI, ...(profileII ?? [])]),
        range: [height - margin.bottom, margin.top],
      }),
    [profileI, profileII, height, margin],
  );

  const xScale = useMemo(
    () =>
      scaleLinear({
        domain: [FIRST_HOUR, LAST_HOUR],
        range: [margin.left, width - margin.right],
      }),
    [width, margin],
  );

  // Tick and grid layout. A single breakpoint test keeps both decisions consistent at exactly 780px.
  const isNarrow = width <= NARROW_WIDTH_BREAKPOINT;
  const numTicks = isNarrow ? LAST_HOUR / 2 : LAST_HOUR;
  // On wide charts the first column tick is hour 1, which sits on the y axis, so its grid line is
  // skipped; narrow charts tick every other hour starting at hour 2, so every tick gets a line.
  const gridColTicks = isNarrow ? xScale.ticks(numTicks) : xScale.ticks(numTicks).slice(1);
  // Skip the lowest row; it normally coincides with the x axis.
  const gridRowTicks = yScale.ticks().slice(1);

  // Moves the crosshair with the cursor and shows profile I's value for the hovered hour when the
  // cursor is within TOOLTIP_HOVER_DISTANCE px (vertically) of the line.
  const handleMouseMove = useCallback(
    (event: React.MouseEvent<SVGRectElement, MouseEvent>) => {
      const svgRect = svgRef.current?.getBoundingClientRect();
      if (!svgRect) return;
      const mouseX = event.clientX - svgRect.left;
      const mouseY = event.clientY - svgRect.top;
      const hoveredTime = xScale.invert(mouseX);
      setCrosshairTime(hoveredTime);

      // Profile I is drawn with curveStepAfter: the value at hour h holds until hour h + 1, so the
      // step under the cursor is floor(time), clamped to the 1-based hours that actually have data.
      const hour = Math.max(FIRST_HOUR, Math.min(Math.floor(hoveredTime), profileI.length));
      const value = profileI[hour - 1];
      if (value === undefined || Math.abs(mouseY - yScale(value)) > TOOLTIP_HOVER_DISTANCE) {
        setTooltipData(null);
        setTooltipPosition(null);
        return;
      }

      // Offset the tooltip from the cursor, then keep it inside the plotting area.
      setTooltipData({hour, value});
      setTooltipPosition({
        x: Math.min(Math.max(mouseX + 60, margin.left), width - margin.right - 50),
        y: Math.min(Math.max(mouseY + 80, margin.top), height - margin.bottom - 1),
      });
    },
    [profileI, xScale, yScale, width, height, margin],
  );

  const handleMouseLeave = useCallback(() => {
    setTooltipData(null);
    setTooltipPosition(null);
    setCrosshairTime(null);
  }, []);

  const crosshairXPosition = useMemo(
    () => (crosshairTime !== null ? xScale(crosshairTime) : null),
    [crosshairTime, xScale],
  );

  return (
    <Grid container spacing={2}>
      <Grid xs={0.5} display={'flex'} justifyContent="center" alignItems="center" marginLeft="auto">
        <ChartLabel
          style={{
            position: 'absolute',
            writingMode: 'vertical-lr',
            transform: 'rotate(180deg)',
            marginLeft: margin.bottom,
          }}>
          U/h
        </ChartLabel>
      </Grid>
      <Grid item xs={11.5}>
        <div style={{position: 'relative'}}>
          <svg ref={svgRef} width="100%" height={height}>
            <Group>
              {/* Y axis */}
              <AxisLeft
                scale={yScale}
                left={margin.left}
                tickFormat={d => d.toString()}
                stroke="black"
                tickStroke="black"
                tickLabelProps={() => ({
                  fill: 'black',
                  fontSize: 11,
                  textAnchor: 'end',
                  dx: '-0.25em',
                  dy: '0.25em',
                })}
              />

              {/* X axis */}
              <AxisBottom
                top={height - margin.bottom}
                scale={xScale}
                stroke="black"
                tickStroke="black"
                numTicks={numTicks}
                tickFormat={d => d.toString()}
                tickLabelProps={() => ({
                  fill: 'black',
                  fontSize: 11,
                  textAnchor: 'middle',
                  dy: '0.25em',
                })}
              />

              {/* Grid */}
              <GridColumns
                top={margin.top}
                scale={xScale}
                numTicks={numTicks}
                height={height - margin.bottom - margin.top}
                stroke="#e0e0e0"
                strokeDasharray="5,5"
                pointerEvents="none"
                tickValues={gridColTicks}
              />

              <GridRows
                scale={yScale}
                width={width - margin.right - margin.left}
                strokeDasharray="2,2"
                left={margin.left}
                tickValues={gridRowTicks}
              />

              {/* Profile I */}
              <LinePath
                data={profileI}
                x={(d, i) => xScale(i + 1) || 0}
                y={d => yScale(d)}
                stroke="#8B8C89"
                strokeWidth={2}
                curve={curveStepAfter}
              />
              {profileI.map((d, i) => (
                <Circle
                  key={`circle-profileI-${i}`}
                  cx={xScale(i + 1) || 0}
                  cy={yScale(d)}
                  r={4}
                  fill="#8B8C89"></Circle>
              ))}

              {/* Profile II (drawn as a step function, like profile I) */}
              {profileII && (
                <LinePath
                  data={profileII}
                  x={(d, i) => xScale(i + 1) || 0}
                  y={d => yScale(d)}
                  stroke="red"
                  strokeWidth={2}
                  curve={curveStepAfter}
                />
              )}
              {profileII &&
                profileII.map((d, i) => (
                  <Circle
                    key={`circle-profileII-${i}`}
                    cx={xScale(i + 1) || 0}
                    cy={yScale(d)}
                    r={4}
                    fill="red"></Circle>
                ))}
              {/* Crosshair */}
              {crosshairXPosition !== null && (
                <line
                  x1={crosshairXPosition}
                  x2={crosshairXPosition}
                  y1={margin.top}
                  y2={height - margin.bottom}
                  stroke="gray"
                  strokeDasharray="4,4"
                />
              )}
              {/* Capture mouse events over the plotting area only (not the axes) */}
              <rect
                x={margin.left}
                y={margin.top}
                width={width - margin.left - margin.right}
                height={height - margin.top - margin.bottom}
                fill="transparent"
                onMouseMove={handleMouseMove}
                onMouseLeave={handleMouseLeave}
              />
            </Group>
          </svg>
          {/* Mouse hovering window */}
          {tooltipData && tooltipPosition && (
            <div
              style={{
                position: 'absolute',
                left: `${tooltipPosition.x}px`,
                top: `${tooltipPosition.y}px`,
                backgroundColor: 'white',
                padding: '5px 10px',
                pointerEvents: 'none',
                transform: 'translate(-50%, -100%)',
                boxShadow: '1px 1px 1px 1px #757575',
                borderRadius: '4px',
                textAlign: 'center',
                maxWidth: '150px',
                whiteSpace: 'nowrap',
                overflow: 'hidden',
                textOverflow: 'ellipsis',
              }}>
              <h4 style={{margin: 0, padding: 0, color: theme.palette.customColors?.lightGrey || 'black'}}>
                Basal profile
              </h4>
              <p style={{fontWeight: 'bold', margin: 0, padding: 0}}>{tooltipData.value.toFixed(2) + ' U/h'}</p>
              <p style={{fontSize: '14px', margin: 0, padding: 0}}>{'at hour ' + tooltipData.hour}</p>
            </div>
          )}
        </div>
      </Grid>
      <Grid xs={12} display={'flex'} justifyContent="center" alignItems="flex-start" style={{marginTop: '1%'}}>
        <ChartLabel style={{position: 'absolute', marginLeft: margin.left}}>Time (h)</ChartLabel>
      </Grid>
    </Grid>
  );
};

export default BasalChart;
