import { Box } from '@chakra-ui/react'
import { addDays } from 'date-fns'
import { TooltipMetricRow } from 'features/dashboard/components/ChartTooltip/shared/TooltipMetricRow'
import { TooltipSectionLabel } from 'features/dashboard/components/ChartTooltip/shared/TooltipSectionLabel'
import { type InferenceGraphPointFieldsFragment } from 'generated/graphql/graphql'
import { METRIC_FORMAT } from 'graphql/statistics/constants'
import { useNormalizedMetrics } from 'graphql/statistics/useMetrics'
import { useMerchantInfo } from 'graphql/useMerchantInfo'
import Highcharts from 'highcharts'
import HighchartsReact from 'highcharts-react-official'
import type React from 'react'
import { renderToString } from 'react-dom/server'
import { colorTheme } from 'ui/theme/colors'
import { CHART_TYPE_ID } from 'utils/chart/chartTypes'
import { CHART_PRIMARY_COLOR, staticChartOptions } from 'utils/chart/constants'
import { formatMetric } from 'utils/numberFormats'
import { MARGIN_DAYS } from './KPIChart'

/** Props accepted by `CausalEffectChart`. */
interface CausalEffectChartProps {
  /**
   * Points to plot. Each point's `date` becomes the x value, `value` the y
   * value, and `lowerBound` / `upperBound` the confidence-interval band.
   */
  cumulativeLift: InferenceGraphPointFieldsFragment[]
  /**
   * Key looked up in the normalized metrics to obtain a display label; the key
   * itself is used as the label when no entry is found.
   */
  targetVariable: string
  /** Date string parsed with `new Date`; start of the "Test period" band. */
  startDate: string
  /**
   * Date string parsed with `new Date`; end of the "Test period" band and
   * start of the "Post-treatment" band.
   */
  endDate: string
  /** Declared here but not currently read by `CausalEffectChart`. */
  treatmentPeriod: number
  /** Number of days added to `endDate` to get the end of the "Post-treatment" band. */
  postTreatmentPeriod: number
}

/**
 * Line chart of the cumulative causal effect on a target metric, drawn on top
 * of a shaded confidence-interval band.
 *
 * The x axis is a datetime axis with two plot bands: "Test period"
 * (`startDate` to `endDate`) and "Post-treatment" (`endDate` to
 * `endDate + postTreatmentPeriod` days). The shared tooltip lists the upper
 * bound, the value and the lower bound for the hovered date, formatted as
 * currency.
 *
 * @param cumulativeLift - Points to plot (date, value, optional bounds).
 * @param targetVariable - Metric key used to look up the display label.
 * @param startDate - Start of the test period (parsed with `new Date`).
 * @param endDate - End of the test period (parsed with `new Date`).
 * @param postTreatmentPeriod - Length of the post-treatment band, in days.
 */
export const CausalEffectChart: React.FC<CausalEffectChartProps> = ({
  cumulativeLift,
  targetVariable,
  startDate,
  endDate,
  postTreatmentPeriod,
}) => {
  const { currency } = useMerchantInfo()
  const normalizedMetrics = useNormalizedMetrics()

  const metricLabel = normalizedMetrics[targetVariable]?.label ?? targetVariable

  // Period boundaries as epoch milliseconds, computed once and reused by the
  // plot bands and plot lines below.
  const startTime = new Date(startDate).getTime()
  const endTime = new Date(endDate).getTime()
  const postTreatmentEndDate = addDays(new Date(endDate), postTreatmentPeriod)

  // Highcharts expects `undefined` rather than `null` for a missing bound.
  const treatmentData = cumulativeLift.map((point) => ({
    x: new Date(point.date).getTime(),
    y: point.value,
    lowerBound: point.lowerBound ?? undefined,
    upperBound: point.upperBound ?? undefined,
  }))

  const options: Highcharts.Options = {
    ...staticChartOptions,
    title: {
      text: 'Causal effect',
      align: 'left',
      floating: true,
      style: {
        fontWeight: '500',
        color: colorTheme.grey[800],
        // Explicit unit, matching the '10px' used for the band labels below.
        fontSize: '14px',
      },
    },
    chart: {
      ...staticChartOptions.chart,
      zooming: { type: 'xy' },
      type: CHART_TYPE_ID.LINE,
      height: 300,
      marginTop: 60,
    },
    xAxis: {
      ...staticChartOptions.xAxis,
      type: 'datetime',
      // Leave MARGIN_DAYS of empty space after the post-treatment band.
      max: addDays(postTreatmentEndDate, MARGIN_DAYS).getTime(),
      plotBands: [
        {
          from: startTime,
          to: endTime,
          color: colorTheme.gray[100],
          label: {
            text: 'Test period',
            style: {
              color: colorTheme.black,
              fontSize: '10px',
            },
          },
        },
        {
          from: endTime,
          to: postTreatmentEndDate.getTime(),
          color: colorTheme.gray[50],
          label: {
            text: 'Post-treatment',
            style: {
              color: colorTheme.black,
              fontSize: '10px',
            },
          },
        },
      ],
      // Dashed vertical markers at the start and end of the test period.
      plotLines: [
        {
          value: startTime,
          color: colorTheme.gray[300],
          dashStyle: 'ShortDash',
          width: 1,
        },
        {
          value: endTime,
          color: colorTheme.gray[300],
          dashStyle: 'ShortDash',
          width: 1,
        },
      ],
    },
    yAxis: {
      ...staticChartOptions.yAxis,
      title: {
        ...staticChartOptions.yAxis?.title,
        text: `${metricLabel} (${currency})`,
      },
      // Dotted baseline at zero effect.
      plotLines: [
        {
          value: 0,
          color: colorTheme.gray[300],
          dashStyle: 'ShortDot',
          width: 1,
        },
      ],
    },
    series: [
      {
        type: 'arearange',
        name: 'Confidence interval',
        data: treatmentData.map(({ x, lowerBound, upperBound }) => ({
          x,
          low: lowerBound,
          high: upperBound,
        })),
        color: colorTheme.gray[400],
        fillColor: colorTheme.gray[200],
        lineWidth: 1,
        dashStyle: 'Dot',
        // Mouse tracking stays enabled: the tooltip formatter below reads this
        // series' point to show the upper and lower bounds.
        showInLegend: false,
      },
      {
        type: 'line',
        name: `Cumulative causal effect on ${metricLabel}`,
        data: treatmentData,
        color: CHART_PRIMARY_COLOR,
        marker: {
          enabled: false,
        },
      },
    ],
    legend: {
      align: 'center',
      verticalAlign: 'top',
      layout: 'horizontal',
    },
    plotOptions: {
      ...staticChartOptions.plotOptions,
      series: {
        ...staticChartOptions.plotOptions.series,
        stickyTracking: true,
      },
    },
    tooltip: {
      ...staticChartOptions.tooltip,
      shared: true,
      useHTML: true,
      // Must remain a `function` expression: Highcharts supplies the hovered
      // point context through `this`.
      formatter: function () {
        if (this.x === undefined || this.x === null) return ''

        // NOTE(review): the date is formatted in the viewer's local time zone;
        // confirm this matches the time settings used by the chart axis.
        const formattedDate = new Date(this.x).toLocaleDateString('en-US', {
          month: 'short',
          day: 'numeric',
          year: 'numeric',
        })

        // Look the points up by series type instead of by position: either
        // series can be missing from `this.points` (for example when the line
        // series is hidden through the legend), which would shift the other
        // point into the wrong slot or leave it undefined.
        const hoveredPoints = this.points ?? []
        const areaRangePoint = hoveredPoints.find(
          (hovered) => hovered.series.type === 'arearange',
        )
        const valuePoint = hoveredPoints.find(
          (hovered) => hovered.series.type === 'line',
        )

        const formatCurrency = (amount: number) =>
          formatMetric(METRIC_FORMAT.CURRENCY, amount, currency)

        const high = areaRangePoint?.point.high
        const low = areaRangePoint?.point.low

        // A bound of exactly 0 is a valid value, so test for "is a number"
        // rather than relying on truthiness.
        const hasAreaRangePoint =
          typeof high === 'number' && typeof low === 'number'

        const formattedHighValue =
          typeof high === 'number' ? formatCurrency(high) : 'N/A'
        const formattedLowValue =
          typeof low === 'number' ? formatCurrency(low) : 'N/A'

        const rawValue = valuePoint?.y
        const value =
          typeof rawValue === 'number' ? formatCurrency(rawValue) : '0'

        const element = (
          <div>
            <TooltipSectionLabel label={formattedDate} />

            {hasAreaRangePoint && (
              <TooltipMetricRow
                iconColor={areaRangePoint?.color?.toString()}
                metricName="Upper"
                value={formattedHighValue}
              />
            )}

            {valuePoint && (
              <TooltipMetricRow
                iconColor={valuePoint.color?.toString()}
                metricName={`${metricLabel}`}
                value={value}
              />
            )}

            {hasAreaRangePoint && (
              <TooltipMetricRow
                iconColor={areaRangePoint?.color?.toString()}
                metricName="Lower"
                value={formattedLowValue}
              />
            )}
          </div>
        )

        // Highcharts expects an HTML string, so render the React tree statically.
        return renderToString(element)
      },
    },
  }

  return (
    <Box px={6} py={5} position="relative">
      <HighchartsReact highcharts={Highcharts} options={options} />
    </Box>
  )
}
