import {
  CompareValueType,
  getOptimalOptimizationMetric,
  optimizationsMetricValuesMap,
  type MetricOptimizationTableColumn,
  type OPTIMIZATION_METRIC,
} from 'features/optimizations/consts'
import { type ContributionData } from 'features/optimizations/graphql/useHistoricalAnalysisQuery'
import { getChannelId } from 'features/optimizations/utils/transformChannel'
import { calcPercentageDiff } from 'utils/calcPercentageDiff'
import { derivedMetrics, totalKey } from './consts'

/** Arguments accepted by `getAggregatedData`. */
type GetAggregatedDataProps = {
  /**
   * Contribution rows to aggregate. `undefined` is accepted (presumably while
   * the query has not returned data yet — TODO confirm with callers).
   */
  contribution: ContributionData | undefined
  /** Table columns whose metric keys (and optimal counterparts) are summed. */
  columns: MetricOptimizationTableColumn[]
}

/**
 * Summed metric values keyed by metric. Despite the full `Record` type, at
 * runtime a row only holds the metrics that were explicitly initialized or
 * derived for it.
 */
type AggregationRow = Record<OPTIMIZATION_METRIC, number>
/** Aggregation rows keyed by `totalKey` or by a channel id. */
type AggregationResult = Record<string, AggregationRow>

/**
 * Returns the aggregation row stored under `aggregationKey` with `rowKey` and
 * its optimal counterpart reset to 0.
 *
 * An existing row is mutated in place and returned. When no row exists yet, a
 * fresh object is created, but it is NOT written back into `accumulator` —
 * the caller is expected to assign the returned row itself.
 */
const initializeAggregation = (
  accumulator: AggregationResult,
  aggregationKey: string,
  rowKey: OPTIMIZATION_METRIC,
) => {
  const row = accumulator[aggregationKey] ?? {}

  row[rowKey] = 0

  const optimalRowKey = getOptimalOptimizationMetric(rowKey)
  row[optimalRowKey] = 0

  return row
}

/**
 * Adds one contribution row's metric values into both the grand-total row
 * (`accumulator[totalKey]`) and the row's own channel entry.
 *
 * For every column, the column's metric and its optimal counterpart (from
 * `getOptimalOptimizationMetric`) are read through
 * `optimizationsMetricValuesMap`. Values that are not finite numbers
 * contribute 0. Columns flagged with `excludeUnoptimizableChannels` are
 * skipped when the row belongs to an unoptimizable channel.
 *
 * Each metric is added at most once per row. Without this guard, a metric
 * referenced by several columns — or a metric that is its own optimal
 * counterpart — would have the same value summed more than once.
 *
 * Assumes `accumulator` already holds initialized rows for `totalKey` and for
 * this row's channel id, as set up by `getAggregatedData`.
 */
const aggregateRow = (
  accumulator: AggregationResult,
  row: ContributionData[number],
  columns: MetricOptimizationTableColumn[],
) => {
  const id = getChannelId(row.channel)
  const totalRow = accumulator[totalKey]
  const channelRow = accumulator[id]
  const alreadyAdded = new Set<OPTIMIZATION_METRIC>()

  const addMetric = (metric: OPTIMIZATION_METRIC) => {
    if (alreadyAdded.has(metric)) {
      return
    }
    alreadyAdded.add(metric)

    const value = optimizationsMetricValuesMap[metric].getValue(row)
    // Number.isFinite does not coerce, so only real finite numbers are summed.
    const safeValue = Number.isFinite(value) ? value : 0

    totalRow[metric] += safeValue
    channelRow[metric] += safeValue
  }

  columns.forEach(({ key, excludeUnoptimizableChannels }) => {
    if (excludeUnoptimizableChannels && isUnoptimizableChannel(row)) {
      return
    }

    addMetric(key)
    addMetric(getOptimalOptimizationMetric(key))
  })
}

/**
 * Aggregates contribution rows into per-channel and grand-total metric rows.
 *
 * The result maps `totalKey` and every distinct channel id (via
 * `getChannelId`) to a row holding, for each column, the summed value of the
 * column's metric and of its optimal counterpart. Metrics listed in
 * `derivedMetrics` are then computed from each aggregated row rather than
 * summed, because they depend on other metrics.
 *
 * @param contribution - Contribution rows; when `undefined`, only the
 *   `totalKey` row is produced.
 * @param columns - Table columns whose metrics are aggregated. An empty list
 *   is allowed and yields rows that contain only derived metrics.
 * @returns Aggregated rows keyed by `totalKey` and by channel id.
 */
export const getAggregatedData = ({
  contribution,
  columns,
}: GetAggregatedDataProps) => {
  // The grand total plus one entry per distinct channel, in first-seen order.
  const rowKeys = Array.from(
    new Set([
      totalKey,
      ...(contribution?.map(({ channel }) => getChannelId(channel)) ?? []),
    ]),
  )

  // Seed every row before initializing metrics, so an empty `columns` list
  // still produces a `totalKey` row instead of crashing in the
  // derived-metric pass below.
  const calculatedData = {} as AggregationResult
  rowKeys.forEach((rowKey) => {
    calculatedData[rowKey] = {} as AggregationRow
    columns.forEach(({ key }) => {
      calculatedData[rowKey] = initializeAggregation(
        calculatedData,
        rowKey,
        key,
      )
    })
  })

  contribution?.forEach((row) => {
    aggregateRow(calculatedData, row, columns)
  })

  // Derived metrics need to be calculated separately since they depend on other metrics
  derivedMetrics.forEach((metric) => {
    rowKeys.forEach((rowKey) => {
      calculatedData[rowKey][metric] = optimizationsMetricValuesMap[
        metric
      ].getValue(calculatedData[rowKey])
    })
  })

  return calculatedData
}

/**
 * Compares `currentValue` against `previousValue`.
 *
 * - `CompareValueType.Percentage`: `calcPercentageDiff(currentValue, previousValue)`
 *   divided by 100. This is presumably a fraction such as 0.25 for +25%;
 *   TODO confirm against `calcPercentageDiff`.
 * - Any other type: the signed difference `currentValue - previousValue`.
 */
export const getCompareValue = (
  type: CompareValueType,
  previousValue: number,
  currentValue: number,
) =>
  type === CompareValueType.Percentage
    ? calcPercentageDiff(currentValue, previousValue) / 100
    : currentValue - previousValue

/** `spendAnalysis.id` value that marks a channel as unoptimizable. */
const unoptimizableChannelId = 'unoptimizable'

/**
 * Returns true when the given contribution row's `spendAnalysis.id` equals
 * the unoptimizable sentinel. A missing row or a missing `spendAnalysis`
 * yields false.
 */
export const isUnoptimizableChannel = (channel?: ContributionData[number]) =>
  unoptimizableChannelId === channel?.spendAnalysis?.id
