import {
  type ExtendedSortDirection,
  type TableState,
} from 'graphql/reports/types'
import { useGetAtomFamilyValue } from 'hooks/useGetAtomValue'
import { type SetStateAction, atom, useAtomValue, useSetAtom } from 'jotai'
import { atomFamily } from 'jotai/utils'
import { focusAtom } from 'jotai-optics'
import { useCallback, useMemo } from 'react'
import { widgetAnalyticsConfigAtom } from '../atoms/dashboardViewState'
import { widgetDateStateAtom } from './useWidgetDateState'
import { widgetDimensionsStateAtom } from './useWidgetDimensionsState'
import { widgetMetricsStateAtom } from './useWidgetMetricState'

/**
 * Builds the initial column layout used when no table state has been stored.
 *
 * Columns are emitted for every dimension key followed by every metric key.
 * All columns start unpinned; only columns whose id equals the first metric
 * key are sorted (descending), every other column is unsorted.
 *
 * @param metricKeys - Selected metric keys, in display order.
 * @param dimensionKeys - Selected dimension keys, in display order.
 * @returns One column entry per key, dimensions first.
 */
const getDefaultColumnState = (
  metricKeys: string[],
  dimensionKeys: string[],
) => {
  // `undefined` when no metric is selected, in which case nothing is sorted.
  const [primaryMetricKey] = metricKeys
  const orderedKeys = [...dimensionKeys, ...metricKeys]

  return orderedKeys.map((columnId) => {
    const isDefaultSortColumn = columnId === primaryMetricKey

    return {
      id: columnId,
      isPinned: false,
      sort: isDefaultSortColumn ? 'desc' : null,
    } as const
  })
}

/**
 * Reconciles a stored table state with the currently selected dimensions and
 * metrics.
 *
 * - An absent or empty `tableState` yields the default column state.
 * - Every selected key that has no column yet gets an unsorted, unpinned
 *   column. It is placed directly after the column of the key that precedes it
 *   in `[...dimensionKeys, ...metricKeys]`, or at the front when it is the
 *   first key. A key that occurs more than once in the key lists is inserted
 *   only once.
 * - Columns whose id is neither a selected dimension nor a selected metric are
 *   dropped.
 * - When `isCompare` is false, `ascCompare` / `descCompare` sorts are cleared
 *   and a missing sort is normalised to `null`.
 *
 * The `tableState` array passed in is not mutated.
 *
 * @param params.tableState - Stored column state, possibly absent.
 * @param params.metricKeys - Selected metric keys, in display order.
 * @param params.dimensionKeys - Selected dimension keys, in display order.
 * @param params.isCompare - Whether compare sort directions are allowed.
 * @returns A table state containing exactly the selected columns.
 */
export const getValidTableState = ({
  tableState,
  metricKeys,
  dimensionKeys,
  isCompare,
}: {
  tableState: TableState | null | undefined
  metricKeys: string[]
  dimensionKeys: string[]
  isCompare: boolean
}): TableState => {
  if (!tableState || tableState.length === 0) {
    return getDefaultColumnState(metricKeys, dimensionKeys)
  }

  const selectedIds = [...dimensionKeys, ...metricKeys]
  const selectedIdsSet = new Set(selectedIds)

  // Work on copies so the caller's array is left untouched.
  const newTableState = [...tableState]
  const columnIds = newTableState.map((column) => column.id)
  const columnIdsSet = new Set(columnIds)

  selectedIds.forEach((id, position) => {
    // Already has a column, or was inserted by an earlier occurrence of the
    // same key: inserting again would produce duplicate column ids.
    if (columnIdsSet.has(id)) {
      return
    }

    // Keys are processed left to right, so every key before `position` already
    // has a column at this point; the direct predecessor is therefore always
    // the closest existing neighbour to the left.
    const insertAt =
      position === 0 ? 0 : columnIds.indexOf(selectedIds[position - 1]) + 1

    newTableState.splice(insertAt, 0, { id, sort: null, isPinned: false })
    columnIds.splice(insertAt, 0, id)
    columnIdsSet.add(id)
  })

  // Keep only selected metrics or dimensions
  const validColumns = newTableState.filter((column) =>
    selectedIdsSet.has(column.id),
  )

  if (isCompare) {
    return validColumns
  }

  // Compare sort keys are meaningless without a comparison period.
  const compareSortKeys = ['ascCompare', 'descCompare']

  return validColumns.map((column) => ({
    ...column,
    sort: compareSortKeys.includes(column.sort ?? '')
      ? null
      : (column.sort ?? null),
  }))
}

/**
 * Per-widget atom focused on the `tableState` property of the widget's
 * analytics config, so the stored table state can be read and written without
 * touching the rest of the config.
 *
 * `optional()` is applied before `prop` — presumably because the analytics
 * config itself may be absent; confirm against `widgetAnalyticsConfigAtom`.
 */
const focusWidgetTableStateStateAtom = atomFamily(
  (widgetId: string | undefined) => {
    const analyticsConfigAtom = widgetAnalyticsConfigAtom(widgetId)

    return focusAtom(analyticsConfigAtom, (configOptic) =>
      configOptic.optional().prop('tableState'),
    )
  },
)

/**
 * Per-widget table state atom.
 *
 * Read: returns the stored table state passed through `getValidTableState`
 * together with the widget's current dimension keys, metric keys and compare
 * flag.
 *
 * Write: forwards the given value to the stored table state of the widget's
 * analytics config.
 *
 * NOTE(review): the write argument is forwarded as-is, so an updater function
 * would be resolved by the focused atom rather than against the validated
 * value returned by the read side — confirm this is intended.
 */
const widgetTableStateAtom = atomFamily((widgetId: string | undefined) =>
  atom(
    (get) => {
      const storedTableState = get(focusWidgetTableStateStateAtom(widgetId))
      const dimensionsState = get(widgetDimensionsStateAtom(widgetId))
      const metricsState = get(widgetMetricsStateAtom(widgetId))
      const dateState = get(widgetDateStateAtom(widgetId))

      return getValidTableState({
        tableState: storedTableState,
        dimensionKeys: dimensionsState.dimensionKeys,
        metricKeys: metricsState.metricKeys,
        isCompare: dateState.resolvedDateState.isCompare,
      })
    },
    (_get, set, nextTableState: SetStateAction<TableState>) => {
      set(focusWidgetTableStateStateAtom(widgetId), nextTableState)
    },
  ),
)

/**
 * Subscribes to the validated table state of the given widget.
 *
 * @param widgetId - Id of the widget whose table state is read.
 * @returns The value of the widget's table state atom.
 */
export const useWidgetTableState = (widgetId: string | undefined) => {
  // Named so it does not shadow the `atom` factory imported from jotai.
  const tableStateAtom = useMemo(
    () => widgetTableStateAtom(widgetId),
    [widgetId],
  )
  const tableState = useAtomValue(tableStateAtom)

  return tableState
}

/**
 * Provides setters for the table state of the given widget.
 *
 * @param widgetId - Id of the widget whose table state is updated.
 * @returns An object with `setTableSorting`, which applies `sortOrder` to the
 * column with the given id, clears the sort of every other column and writes
 * the result to the widget's table state atom.
 */
export const useSetWidgetTableState = (widgetId: string | undefined) => {
  const setTableState = useSetAtom(widgetTableStateAtom(widgetId))
  const getCurrentTableState = useGetAtomFamilyValue(widgetTableStateAtom)

  const setTableSorting = useCallback(
    (columnId: string, sortOrder: ExtendedSortDirection) => {
      // Read the state at call time rather than subscribing to it.
      const currentTableState: TableState = getCurrentTableState(widgetId)

      // Only one column is sorted at a time: the target column receives the
      // requested order, all other columns are reset.
      const resortedTableState = currentTableState.map((column) => {
        const isTargetColumn = column.id === columnId

        return {
          ...column,
          sort: isTargetColumn ? sortOrder : null,
        }
      })

      setTableState(resortedTableState)
    },
    [getCurrentTableState, setTableState, widgetId],
  )

  return { setTableSorting }
}
