import { AxisBottom, AxisLeft } from "@visx/axis";
import { GridColumns, GridRows } from "@visx/grid";
import { useCallback, useContext, useEffect, useMemo, useState } from "react";
import AutoSizer, { Size } from "react-virtualized-auto-sizer";
import { DatapointAggregate } from "@cognite/sdk";
import { scaleLinear, scaleTime } from "@visx/scale";
import dayjs, { Dayjs } from "@properate/dayjs";
import { Area, Bar, Line, LinePath } from "@visx/shape";
import { MarkerCircle } from "@visx/marker";
import { Group } from "@visx/group";
import {
  formatUnit,
  getFractionDigits,
  DERIVED_UNITS,
} from "@properate/common";
import { defaultStyles, TooltipWithBounds, useTooltip } from "@visx/tooltip";
import { localPoint } from "@visx/event";
import { Col, Row } from "antd";
import { ThemeContext } from "styled-components";
import _ from "lodash";
import Color from "color";
import { RectClipPath } from "@visx/clip-path";
import { datapointsRetrieve } from "@/utils/helpers";
import { formatMultiState, formatNonScientific } from "@/utils/format";
import { useCogniteClient } from "@/context/CogniteClientContext";
import { getDatapointClosestToDate } from "@/utils/datapoints";
import { maxPixelLengthInStringArray } from "@/pages/common/SchemaGraph/utils";

type InnerGraphProps = {
  /** Main series that is fetched and drawn as the highlighted line. */
  timeseriesId: number;
  /**
   * Optional companion series fetched together with the main one and drawn in
   * a neutral colour. When both are set, `minTimeseriesId` is used.
   */
  minTimeseriesId?: number;
  maxTimeseriesId?: number;
  /**
   * Series whose values are read as timeseries ids; they are resolved to room
   * labels that appear in the tooltip.
   */
  timeseriesRoomId?: number;
  /** Inclusive start and end of the plotted time range. */
  start: Dayjs;
  end: Dayjs;
  /** Unit the values are displayed in. */
  unit: string;
  /** Optional thresholds drawn as dashed horizontal lines. */
  min?: number;
  max?: number;
  /** Optional fixed y-axis bounds; plotted values are clamped into them. */
  yAxisMin?: number;
  yAxisMax?: number;
  /** Labels for discrete states, keyed by the rounded value. */
  stateDescription?: Record<number, string>;
  /** Available width in pixels. */
  width: number;
  /** Available height in pixels (InnerGraph derives its height from `width`). */
  height: number;
};

/**
 * Point count below which raw datapoints are fetched; at or above it the
 * chart switches to average/min/max aggregates.
 */
const AGGREGATION_THRESHOLD = 1000;

/**
 * Clamps `min`, `max` and `average` of every datapoint into the window
 * `[yAxisMin, yAxisMax]`. Either bound may be omitted, leaving that side open.
 * When both are omitted the input array itself is returned untouched.
 *
 * @param data - Series with their datapoints.
 * @param yAxisMin - Optional lower bound.
 * @param yAxisMax - Optional upper bound.
 * @returns New series objects with clamped datapoints, or `data` unchanged.
 */
const clampMinMax = (
  data: {
    id: number;
    datapoints: {
      timestamp: Date;
      min: number;
      max: number;
      average: number;
    }[];
  }[],
  yAxisMin?: number,
  yAxisMax?: number,
) => {
  if (yAxisMin === undefined && yAxisMax === undefined) {
    return data;
  }

  // The upper bound is tested first, so with inverted bounds a value above
  // yAxisMax still resolves to yAxisMax.
  const clamp = (value: number) => {
    if (yAxisMax !== undefined && value > yAxisMax) {
      return yAxisMax;
    }
    if (yAxisMin !== undefined && value < yAxisMin) {
      return yAxisMin;
    }
    return value;
  };

  return data.map((series) => ({
    ...series,
    datapoints: series.datapoints.map((point) => ({
      ...point,
      min: clamp(point.min),
      max: clamp(point.max),
      average: clamp(point.average),
    })),
  }));
};

/**
 * Renders one time-series chart (plus an optional companion series) for the
 * range `start`–`end`.
 *
 * Data loading:
 * - The raw point count of the main series is fetched first. Below
 *   `AGGREGATION_THRESHOLD` raw datapoints are fetched and each `value` is
 *   copied into `average`. Otherwise `average`/`min`/`max` aggregates are
 *   fetched at a granularity derived from the range length.
 * - Values are converted from the series' unit into `unit` via
 *   `DERIVED_UNITS`, or scaled by 100 for `%-akt`/`%-norm` series, and then
 *   clamped to `yAxisMin`/`yAxisMax`.
 * - When `timeseriesRoomId` is set, its values are read as timeseries ids.
 *   Each id is resolved to the name/description of the parent asset of that
 *   timeseries' asset, and the label is shown in the tooltip.
 *
 * The rendered height is derived from `width` (`width * 0.29`); the `height`
 * prop is not read.
 */
const InnerGraph = ({
  timeseriesId,
  width,
  start,
  end,
  min,
  max,
  minTimeseriesId,
  maxTimeseriesId,
  timeseriesRoomId,
  unit,
  yAxisMin,
  yAxisMax,
  stateDescription,
}: InnerGraphProps) => {
  const { client } = useCogniteClient();
  const themeContext = useContext(ThemeContext);
  const [data, setData] = useState<any[]>();
  const [aggregated, setAggregated] = useState<boolean>();
  const [marker, setMarker] = useState<any>();

  // Room label per datapoint timestamp (ms since epoch), used by the tooltip.
  const [roomNames, setRoomNames] = useState<
    Record<number, string> | undefined
  >();

  const {
    tooltipData,
    tooltipLeft,
    tooltipTop,
    tooltipOpen,
    showTooltip,
    hideTooltip,
  } = useTooltip<JSX.Element>();

  // Aggregate granularity by range length: 5 minutes up to a day, hourly below
  // 1000 hours, daily below 1000 days.
  // NOTE(review): ranges of 1000+ days fall back to hourly, which produces more
  // buckets than daily would — confirm this is intended.
  const granularity: "h" | "24h" | "5m" = useMemo(() => {
    const hours = (end.valueOf() - start.valueOf()) / 1000 / 60 / 60;
    if (hours <= 24) {
      return "5m";
    }
    if (hours < 1000) {
      return "h";
    }
    if (hours / 24 < 1000) {
      return "24h";
    }
    return "h";
  }, [start, end]);

  useEffect(() => {
    // Set by the cleanup so a slow response for an outdated range or series
    // cannot overwrite the data of a newer request.
    let cancelled = false;

    const get = async () => {
      // The companion series is fetched together with the main one;
      // minTimeseriesId takes precedence over maxTimeseriesId.
      const items = minTimeseriesId
        ? [{ id: timeseriesId }, { id: minTimeseriesId }]
        : maxTimeseriesId
        ? [{ id: timeseriesId }, { id: maxTimeseriesId }]
        : [{ id: timeseriesId }];

      const countResponse = await client.datapoints.retrieve({
        aggregates: ["count"],
        granularity: "365d",
        items: [{ id: timeseriesId }],
        start: start.valueOf(),
        end: end.valueOf(),
      });
      // A range longer than a year spans several "365d" buckets, so all bucket
      // counts are summed instead of reading only the first one.
      const countDatapoints = (countResponse[0]?.datapoints ??
        []) as DatapointAggregate[];
      const pointCount = countDatapoints.reduce(
        (total, datapoint) => total + (datapoint.count ?? 0),
        0,
      );
      const isAggregated = !pointCount || pointCount >= AGGREGATION_THRESHOLD;

      const results = !isAggregated
        ? (
            await client.datapoints.retrieve({
              items,
              start: start.valueOf(),
              end: end.valueOf(),
              limit: AGGREGATION_THRESHOLD,
            })
          ).map((r) => ({
            ...r,
            // Raw points expose `value`; mirror it into `average` so both
            // paths can be rendered the same way.
            datapoints: r.datapoints.map(
              (d: any) =>
                ({
                  ...d,
                  average: d.value,
                }) as DatapointAggregate,
            ),
          }))
        : await datapointsRetrieve(client)({
            start: start.valueOf(),
            end: end.valueOf(),
            items,
            aggregates: ["average", "min", "max"],
            granularity,
            includeOutsidePoints: false,
            limit: 1000,
          });

      if (cancelled) {
        return;
      }

      // "%-akt" and "%-norm" are treated as "%" when deciding whether a unit
      // conversion is needed.
      const cunit = ["%-akt", "%-norm"].includes(unit || "") ? "%" : unit || "";
      const runit = ["%-akt", "%-norm"].includes(results[0].unit || "")
        ? "%"
        : results[0].unit || "";

      if (cunit !== runit) {
        const newResults = results.map((result) => ({
          ...result,
          datapoints: result.datapoints.map((val: any) => {
            const average = DERIVED_UNITS[runit][cunit].to(val.average);
            const min =
              val.min !== undefined
                ? DERIVED_UNITS[runit][cunit].to(val.min)
                : undefined;
            const max =
              val.max !== undefined
                ? DERIVED_UNITS[runit][cunit].to(val.max)
                : undefined;
            return {
              ...val,
              average,
              min,
              max,
            };
          }),
        }));

        setData(clampMinMax(newResults, yAxisMin, yAxisMax));
        setAggregated(isAggregated);
      } else {
        // Fractional percentage series are scaled to 0–100. min/max are scaled
        // with the average so the aggregate band stays aligned with the line;
        // raw datapoints carry no min/max and are left without them.
        const newResults = results.map((result) => ({
          ...result,
          datapoints: ["%-akt", "%-norm"].includes(result.unit || "")
            ? (result.datapoints as DatapointAggregate[]).map((val: any) => ({
                ...val,
                average: val.average * 100,
                ...(typeof val.min === "number" ? { min: val.min * 100 } : {}),
                ...(typeof val.max === "number" ? { max: val.max * 100 } : {}),
              }))
            : result.datapoints,
        }));

        setData(clampMinMax(newResults, yAxisMin, yAxisMax));
        setAggregated(isAggregated);
      }
    };

    if (start && end) {
      get().catch((error) => {
        console.error(error);
      });
    }

    return () => {
      cancelled = true;
    };
  }, [
    end,
    start,
    timeseriesId,
    unit,
    client,
    maxTimeseriesId,
    minTimeseriesId,
    granularity,
    yAxisMin,
    yAxisMax,
  ]);

  useEffect(() => {
    // Set by the cleanup so a lookup for outdated data does not overwrite the
    // room names of the current data.
    let cancelled = false;

    const get = async () => {
      const GRANULARITY_TO_MS = {
        "5m": 5 * 60 * 1000,
        h: 60 * 60 * 1000,
        "24h": 24 * 60 * 60 * 1000,
      };

      const batches = _.chunk(data![0].datapoints, 100);

      // For every plotted datapoint, read the latest room-series value before
      // the end of that datapoint's bucket; the value is a timeseries id.
      const roomTimeseriesIds = (
        await Promise.all(
          batches.map(async (batch: any) => {
            const latest = await client.datapoints.retrieveLatest(
              batch.map((d: any) => {
                return {
                  id: timeseriesRoomId,
                  before:
                    d.timestamp.valueOf() + GRANULARITY_TO_MS[granularity],
                };
              }),
            );
            return latest.map((l: any) =>
              l.datapoints.length > 0 ? l.datapoints[0].value : undefined,
            );
          }),
        )
      )
        .flat()
        .map((id, index) => ({
          id,
          timestamp: data![0].datapoints[index].timestamp,
        }));

      // remove undefined and duplicates
      const idBatches: number[][] = _.chunk(
        Array.from(
          new Set(roomTimeseriesIds.map((t) => t.id).filter((id) => !!id)),
        ),
        100,
      );

      // Resolve timeseries id -> asset -> parent asset label.
      // NOTE(review): parents are indexed by batch position, which assumes
      // assets.retrieve returns one item per requested id in request order.
      const roomIds = await idBatches.reduce<Promise<Record<number, string>>>(
        async (prev, batch) => {
          const timeseries = await client.timeseries.retrieve(
            batch.map((id: any) => ({ id })),
          );
          const assets = await client.assets.retrieve(
            timeseries.map((t: any) => ({ id: t.assetId })),
          );
          const parents = await client.assets.retrieve(
            assets.map((t: any) => ({ id: t.parentId })),
          );
          const last = await prev;
          return parents.reduce(
            (acc, asset: any, index) => ({
              ...acc,
              [batch[index]]: `${asset.name} ${asset.description || ""}`,
            }),
            last,
          );
        },
        Promise.resolve({}),
      );

      if (cancelled) {
        return;
      }

      setRoomNames(
        roomTimeseriesIds.reduce(
          (acc: any, data: any) => ({
            ...acc,
            [data.timestamp.getTime()]: roomIds[data.id],
          }),
          {},
        ),
      );
    };

    if (
      data &&
      data[0].datapoints &&
      data[0].datapoints.length > 0 &&
      timeseriesRoomId
    ) {
      get().catch((error) => {
        console.error(error);
      });
    }

    return () => {
      cancelled = true;
    };
  }, [end, start, client, timeseriesRoomId, data, granularity]);

  // Left margin grows to fit the widest state label on the y-axis.
  const margin = useMemo(
    () => ({
      top: 10,
      right: 0,
      bottom: 40,
      left:
        (maxPixelLengthInStringArray(
          10,
          "Arial",
          Object.values(stateDescription || {}),
        ) || 20) + 20,
    }),
    [stateDescription],
  );

  const height = width * 0.29;
  const xMax = width - margin.left - margin.right;
  const yMax = height - margin.top - margin.bottom;

  const timeScale = useMemo(
    () =>
      scaleTime<number>({
        range: [0, xMax],
        domain: [start.valueOf(), end.valueOf()],
      }),
    [xMax, start, end],
  );

  // The y-domain covers every plotted value (and min/max when aggregated) plus
  // the threshold lines, unless fixed by yAxisMin/yAxisMax, padded by 10%.
  const valueScale = useMemo(() => {
    if (data) {
      const values = [
        ...data
          .map((d) => d.datapoints)
          .flat()
          .map((point: any) => point.average),
        ...(aggregated
          ? data
              .map((d) => d.datapoints)
              .flat()
              .map((point: any) => point.max)
          : []),
        ...(aggregated
          ? data
              .map((d) => d.datapoints)
              .flat()
              .map((point: any) => point.min)
          : []),
      ];

      const calculatedMinThreshold =
        min !== undefined &&
        unit &&
        formatUnit(data[0].unit) !== formatUnit(unit)
          ? DERIVED_UNITS[data[0].unit || ""][unit].to(min)
          : min;

      const calculatedMaxThreshold =
        max !== undefined &&
        unit &&
        formatUnit(data[0].unit) !== formatUnit(unit)
          ? DERIVED_UNITS[data[0].unit || ""][unit].to(max)
          : max;

      const maxMinValues = [...values];

      if (calculatedMinThreshold !== undefined) {
        maxMinValues.push(calculatedMinThreshold);
      }

      if (calculatedMaxThreshold !== undefined) {
        maxMinValues.push(calculatedMaxThreshold);
      }

      const minimum = yAxisMin ?? Math.min(...maxMinValues);

      const maximum = yAxisMax ?? Math.max(...maxMinValues);

      return scaleLinear<number>({
        range: [yMax, 0],
        domain: [
          minimum - (maximum - minimum) * 0.1,
          maximum + (maximum - minimum) * 0.1,
        ],
        nice: true,
      });
    }
    return undefined;
  }, [data, yMax, min, max, unit, aggregated, yAxisMin, yAxisMax]);

  // Shows the datapoint of the main series closest to the cursor and moves
  // the marker onto it.
  const handleTooltip = useCallback(
    (event: React.MouseEvent<SVGRectElement>) => {
      const getNumberFormat = (unit: string) =>
        new Intl.NumberFormat("nb-NO", {
          maximumFractionDigits: Math.max(getFractionDigits(unit), 2),
        });

      if (data && data[0].datapoints.length > 0) {
        const { x, y } = localPoint(event) || { x: 0, y: 0 };
        const date = timeScale.invert(x - margin.left);

        const datapointClosestToDate = getDatapointClosestToDate<{
          min: number;
          max: number;
          average: number;
          timestamp: Date;
        }>(data[0].datapoints, date);

        setMarker({
          y: valueScale!(datapointClosestToDate!.average),
          x: timeScale(datapointClosestToDate!.timestamp.getTime()),
        });

        showTooltip({
          tooltipData: (
            <div style={{ zIndex: 100000 }}>
              <Row gutter={[8, 8]}>
                <Col sm={24} style={{ lineHeight: "22px", textAlign: "right" }}>
                  <strong>
                    {dayjs(datapointClosestToDate!.timestamp).format(
                      "DD. MMMM HH:mm",
                    )}
                  </strong>
                </Col>
              </Row>
              <Row gutter={[8, 8]}>
                <Col style={{ lineHeight: "22px", textAlign: "right" }} sm={24}>
                  {!aggregated ||
                  datapointClosestToDate!.max === datapointClosestToDate!.min
                    ? stateDescription
                      ? stateDescription[
                          Math.round(datapointClosestToDate!.average)
                        ]
                      : `${getNumberFormat(data[0].unit).format(
                          datapointClosestToDate!.average,
                        )}${formatUnit(unit || data[0].unit)}`
                    : `Max: ${getNumberFormat(data[0].unit).format(
                        datapointClosestToDate!.max,
                      )}${formatUnit(
                        unit || data[0].unit,
                      )} Snitt: ${getNumberFormat(data[0].unit).format(
                        datapointClosestToDate!.average,
                      )}${formatUnit(
                        unit || data[0].unit,
                      )} Min: ${getNumberFormat(data[0].unit).format(
                        datapointClosestToDate!.min,
                      )}${formatUnit(unit || data[0].unit)}`}
                </Col>
              </Row>
              {roomNames &&
                roomNames[datapointClosestToDate!.timestamp.getTime()] && (
                  <Col
                    style={{ lineHeight: "22px", textAlign: "right" }}
                    sm={24}
                  >
                    {roomNames[datapointClosestToDate!.timestamp.getTime()]}
                  </Col>
                )}
            </div>
          ),
          tooltipLeft: x,
          tooltipTop: y - 30,
        });
      }
    },
    [
      showTooltip,
      valueScale,
      timeScale,
      data,
      margin,
      unit,
      roomNames,
      aggregated,
      stateDescription,
    ],
  );

  return (
    <div style={{ display: "inline-block", transform: "translate(0,-50%)" }}>
      <svg width={width} height={height} style={{ overflow: "visible" }}>
        <RectClipPath x={0} y={0} width={xMax + 10} height={yMax} id="clip" />
        <rect
          x={0}
          y={0}
          width={width}
          height={height}
          fill={themeContext.neutral9}
        />
        <Group left={margin.left} top={margin.top}>
          <MarkerCircle
            id="marker-circle"
            fill={themeContext.accent2}
            size={3}
            refX={2}
          />
          <MarkerCircle
            id="marker-circle-compare"
            fill={themeContext.neutral2}
            size={2}
            refX={1}
          />

          {valueScale && (
            <GridRows
              scale={valueScale}
              width={xMax}
              height={yMax}
              stroke={themeContext.neutral4}
            />
          )}
          <GridColumns
            scale={timeScale}
            width={xMax}
            height={yMax}
            stroke={themeContext.neutral4}
          />
          {valueScale && (
            <AxisLeft
              scale={valueScale}
              numTicks={5}
              tickStroke={themeContext.neutral4}
              tickFormat={
                stateDescription
                  ? formatMultiState(stateDescription)
                  : formatNonScientific
              }
              tickLabelProps={{
                fill: themeContext.neutral4,
                width: margin.left + 100,
              }}
            />
          )}
          <AxisBottom
            top={yMax}
            scale={timeScale}
            stroke={themeContext.neutral4}
            tickStroke={themeContext.neutral4}
            tickFormat={(date: any) => {
              const m = dayjs(date);
              if (m.hour() === 0 && m.minute() === 0 && m.dayOfYear() === 1) {
                return m.format("YYYY");
              } else if (m.hour() === 0 && m.minute() === 0 && m.date() === 1) {
                return m.format("MMM");
              } else if (m.hour() === 0 && m.minute() === 0) {
                return m.format("dd D.");
              }
              return m.format("HH:mm");
            }}
            tickLabelProps={{
              fill: themeContext.neutral4,
            }}
          />

          <g clipPath="url(#clip)" fill="green">
            {valueScale && typeof min === "number" && (
              <Line
                from={{
                  x: 0,
                  y: valueScale(
                    unit && formatUnit(data![0].unit) !== formatUnit(unit)
                      ? DERIVED_UNITS[data![0].unit || ""][unit].to(min)
                      : min,
                  ),
                }}
                to={{
                  x: width - margin.right - margin.left,
                  y: valueScale(
                    unit && formatUnit(data![0].unit) !== formatUnit(unit)
                      ? DERIVED_UNITS[data![0].unit || ""][unit].to(min)
                      : min,
                  ),
                }}
                stroke={themeContext.neutral2}
                strokeWidth={2}
                pointerEvents="none"
                strokeDasharray="5,2"
              />
            )}
            {valueScale && typeof max === "number" && data && (
              <Line
                from={{
                  x: 0,
                  y: valueScale(
                    unit && formatUnit(data[0].unit) !== formatUnit(unit)
                      ? DERIVED_UNITS[data[0].unit || ""][unit].to(max)
                      : max,
                  ),
                }}
                to={{
                  x: width - margin.right - margin.left,
                  y: valueScale(
                    unit && formatUnit(data[0].unit) !== formatUnit(unit)
                      ? DERIVED_UNITS[data[0].unit || ""][unit].to(max)
                      : max,
                  ),
                }}
                stroke={themeContext.neutral2}
                strokeWidth={2}
                pointerEvents="none"
                strokeDasharray="5,2"
              />
            )}
            {valueScale && data && data[1] && (
              <LinePath
                data={data[1].datapoints}
                x={(d: any) => timeScale(d.timestamp.getTime())}
                y={(d: any) => valueScale(d.average)}
                clipPathUnits="objectBoundingBox"
                clipRule="evenodd"
                stroke={themeContext.neutral2}
                strokeWidth={2}
              />
            )}
            {aggregated && valueScale && data && data[1] && (
              <Area
                data={data[1].datapoints}
                x={(d: any) => timeScale(d.timestamp.getTime())}
                y0={(d: any) => valueScale(d.min)}
                y1={(d: any) => valueScale(d.max)}
                fill={Color(themeContext.neutral2).alpha(0.5).string()}
              />
            )}
            {valueScale && data && data[0] && (
              <LinePath
                data={data[0].datapoints}
                x={(d: any) => timeScale(d.timestamp.getTime())}
                y={(d: any) => valueScale(d.average)}
                markerMid={aggregated ? undefined : "url(#marker-circle)"}
                clipPathUnits="objectBoundingBox"
                clipRule="evenodd"
                stroke={"#FFA600"}
                strokeWidth={aggregated ? 2 : 1}
              />
            )}
            {aggregated && valueScale && data && data[0] && (
              <Area
                data={data[0].datapoints}
                x={(d: any) => timeScale(d.timestamp.getTime())}
                y0={(d: any) => valueScale(d.min)}
                y1={(d: any) => valueScale(d.max)}
                fill={Color("#FFA600").alpha(0.5).string()}
              />
            )}
            {!aggregated && marker && (
              <circle
                cx={marker.x}
                cy={marker.y}
                r={4}
                fill="#FFA600"
                stroke="#666"
                strokeWidth={2}
                pointerEvents="none"
              />
            )}
            {aggregated && marker && (
              <Line
                from={{ x: marker.x, y: 0 }}
                to={{ x: marker.x, y: yMax }}
                strokeDasharray="3,3"
                stroke="#666"
                strokeWidth={2}
                pointerEvents="none"
              />
            )}
          </g>
          <Bar
            x={0}
            y={0}
            width={width - margin.left - margin.right}
            height={height - margin.top - margin.bottom}
            fill="transparent"
            onMouseMove={handleTooltip}
            onMouseLeave={() => {
              setMarker(undefined);
              hideTooltip();
            }}
          />
        </Group>
      </svg>
      {tooltipOpen && (
        <TooltipWithBounds
          // set this to random so that it correctly updates with parent bounds
          key={Math.random()}
          top={tooltipTop}
          left={tooltipLeft}
          style={{ ...defaultStyles }}
        >
          {tooltipData}
        </TooltipWithBounds>
      )}
    </div>
  );
};

/** Props of `Graph`; the chart size is measured by `AutoSizer` instead. */
type Props = {
  /** Main series to plot. */
  timeseriesId: number;
  /** Optional companion series; `minTimeseriesId` wins when both are set. */
  minTimeseriesId?: number;
  maxTimeseriesId?: number;
  /** Series whose values identify the room shown in the tooltip. */
  timeseriesRoomId?: number;
  /** Plotted time range. */
  start: Dayjs;
  end: Dayjs;
  /** Display unit. */
  unit: string;
  /** Optional dashed threshold lines. */
  min?: number;
  max?: number;
  /** Optional fixed y-axis bounds. */
  yAxisMin?: number;
  yAxisMax?: number;
  /** Labels for discrete states, keyed by the rounded value. */
  stateDescription?: Record<number, string>;
};
/**
 * Responsive chart: measures the available space with `AutoSizer` and renders
 * `InnerGraph` with that size once both `start` and `end` are set.
 */
export const Graph = (props: Props) => {
  const { start, end } = props;

  return (
    <AutoSizer>
      {({ width, height }: Size) =>
        start &&
        end && <InnerGraph {...props} width={width} height={height} />
      }
    </AutoSizer>
  );
};
