import styles from "./DataTable.module.css";
import { Grid, Pagination, Table } from "@intility/bifrost-react";
import type { BreakpointValue } from "@intility/bifrost-react/Breakpoint";
import {
  type ColumnDef,
  flexRender,
  getCoreRowModel,
  getExpandedRowModel,
  getFilteredRowModel,
  getPaginationRowModel,
  getSortedRowModel,
  type InitialTableState,
  type OnChangeFn,
  type PaginationState,
  type RowData,
  type TableMeta,
  type TableState,
  useReactTable,
} from "@tanstack/react-table";
import type { ReactElement } from "react";
import { cn } from "~/utils/clsx";
import { DataTableSkeletonRows } from "./DataTableSkeleton";

/**
 * Translates a TanStack column sort direction (`column.getIsSorted()`) into
 * the value handed to Bifrost's `Table.HeaderCell` `sorting` prop. Both
 * libraries use the same words, so every key maps to itself.
 */
const sortDirections = {
  asc: "asc",
  desc: "desc",
} as const;

// Extends TanStack Table's meta types with the extra options DataTable reads.
declare module "@tanstack/table-core" {
  // The type parameters must match TanStack's own declarations for the
  // interfaces to merge, even though they are not used here.
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  interface TableMeta<TData extends RowData> {
    /** How many skeleton rows to show while loading. Defaults to the page size. */
    skeletonRows?: number;
  }

  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  interface ColumnMeta<TData extends RowData, TValue> {
    /** Extra class names applied to this column's header cell. */
    headerClassName?: string;
    /** Extra class names applied to each of this column's body cells. */
    cellClassName?: string;
    /**
     * Breakpoint at which the column starts to appear; emitted as a
     * `from-<breakpoint>` class on the column's header and body cells.
     */
    fromSize?: BreakpointValue | null;
  }
}

/** Props accepted by the `DataTable` component. */
type BaseDataTableProps<TData> = {
  /** Class name forwarded to the Bifrost `Table`. */
  className?: string;
  /** TanStack column definitions. */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  columns: ColumnDef<TData, any>[];
  /** Row data; `undefined` is treated as an empty array. */
  data: TData[] | undefined;
  /** Forwarded to TanStack Table; invoked when the page index/size changes. */
  onPaginationChange?: OnChangeFn<PaginationState>;
  /** Initial (uncontrolled) TanStack table state. */
  initialState?: InitialTableState;
  /**
   * Controlled TanStack table state. Supplying `pagination` here enables
   * client-side pagination and the `Pagination` control.
   */
  state?: Partial<TableState>;
  /** Table-level metadata, e.g. `skeletonRows`. */
  meta?: TableMeta<TData>;
  /** Render skeleton rows instead of data rows. */
  isLoading?: boolean;
  /** Forwarded to the Bifrost `Table`'s `noBorder` prop. */
  noBorder?: boolean;
  /** Hide the header row(s). */
  noHeader?: boolean;
  /** Drop the `bfl-padding` class from the expanded-row content wrapper. */
  noExpandPadding?: boolean;
  /** Text shown when there are no rows to display. */
  noResultsText?: string;
  /** Called with the row's original data when a row is clicked. */
  onRowClick?: (rowData: TData) => void;
  /** Forwarded to Bifrost's `Table.Row` `limitExpandClick` prop. */
  limitExpandClick?: boolean;
  /**
   * Renders the expanded content for a row; providing it makes rows
   * expandable. May return `null` to render nothing for a given row.
   */
  ExpandedRowComponent?: (props: { rowData: TData }) => ReactElement | null;
};

type DataTableProps<TData> = BaseDataTableProps<TData>;

// Module-level so the same array reference is reused on every render when
// `data` is undefined; TanStack Table recommends a stable `data` reference.
const fallbackData: never[] = [];

/**
 * Combines a column's custom class with its optional `fromSize` breakpoint,
 * which is emitted as a `from-<breakpoint>` class when set.
 */
const getColumnClassName = (
  className: string | undefined,
  fromSize: BreakpointValue | null | undefined,
) =>
  cn(className, {
    [`from-${fromSize}`]: fromSize,
  });

/**
 * Generic data table that wires TanStack Table into Bifrost's `Table`.
 *
 * - Filtering, sorting and expanded row models are always enabled.
 * - Client-side pagination, and the `Pagination` control below the table,
 *   are enabled only when `state.pagination` is supplied.
 * - When `ExpandedRowComponent` is supplied, each row gets expandable content
 *   that renders it with the row's original data.
 * - While `isLoading` is true, skeleton rows replace the data rows.
 * - Column groups are supported: placeholder headers render as empty cells
 *   and every header cell spans `header.colSpan` columns.
 */
export const DataTable = <TData,>({
  className,
  columns,
  data,
  onPaginationChange,
  initialState,
  state,
  meta,
  isLoading,
  onRowClick,
  noBorder,
  noHeader,
  noExpandPadding,
  noResultsText = "No results found.",
  limitExpandClick,
  ExpandedRowComponent,
}: DataTableProps<TData>) => {
  const table = useReactTable({
    data: data ?? fallbackData,
    columns,
    getCoreRowModel: getCoreRowModel(),
    getFilteredRowModel: getFilteredRowModel(),
    getSortedRowModel: getSortedRowModel(),
    // Only paginate client-side when the caller controls pagination state.
    getPaginationRowModel: state?.pagination
      ? getPaginationRowModel()
      : undefined,
    getExpandedRowModel: getExpandedRowModel(),
    onPaginationChange,
    state,
    initialState,
    meta,
  });

  const isExpandingEnabled = Boolean(ExpandedRowComponent);
  const headerGroups = table.getHeaderGroups();
  const rows = table.getRowModel().rows;

  // The "no results" cell spans exactly the rendered columns: one per visible
  // leaf column, plus the expand column added when expanding is enabled.
  // (A huge colSpan would make the browser grow the table grid to match.)
  const totalColumnCount = Math.max(
    1,
    table.getVisibleLeafColumns().length + (isExpandingEnabled ? 1 : 0),
  );

  return (
    <Grid gap={24}>
      <Table noBorder={noBorder} className={className}>
        {!noHeader && (
          <Table.Header>
            {headerGroups.map((headerGroup) => (
              <Table.Row key={headerGroup.id}>
                {/* Empty header above the expand column; not rendered while loading. */}
                {isExpandingEnabled && !isLoading && <Table.HeaderCell />}

                {headerGroup.headers.map((header) => {
                  const { column } = header;
                  const { headerClassName, fromSize } =
                    column.columnDef.meta ?? {};
                  const headerCellClassName = getColumnClassName(
                    headerClassName,
                    fromSize,
                  );

                  // TanStack pads header rows with placeholder headers when
                  // columns are grouped at different depths. A placeholder
                  // points at a leaf column whose real header renders in a
                  // lower row, so rendering its title or sort control here
                  // would duplicate it.
                  if (header.isPlaceholder) {
                    return (
                      <Table.HeaderCell
                        key={header.id}
                        className={headerCellClassName}
                        colSpan={header.colSpan}
                      />
                    );
                  }

                  const isSorted = column.getIsSorted();
                  const sortDirection = isSorted
                    ? sortDirections[isSorted]
                    : "none";

                  return (
                    <Table.HeaderCell
                      key={header.id}
                      className={headerCellClassName}
                      colSpan={header.colSpan}
                      sorting={column.getCanSort() ? sortDirection : undefined}
                      onClick={column.getToggleSortingHandler()}
                    >
                      {flexRender(column.columnDef.header, header.getContext())}
                    </Table.HeaderCell>
                  );
                })}
              </Table.Row>
            ))}
          </Table.Header>
        )}

        <Table.Body>
          {isLoading ? (
            <DataTableSkeletonRows table={table} />
          ) : rows.length > 0 ? (
            rows.map((row) => {
              // Covers both per-row expanded entries and the expand-all (`true`) state.
              const isRowExpanded = row.getIsExpanded();

              return (
                <Table.Row
                  key={row.id}
                  open={isRowExpanded}
                  onClick={onRowClick && (() => onRowClick(row.original))}
                  onOpenChange={() => row.toggleExpanded()}
                  limitExpandClick={limitExpandClick}
                  content={
                    ExpandedRowComponent ? (
                      <div
                        className={cn({
                          "bfl-padding": !noExpandPadding,
                        })}
                      >
                        <ExpandedRowComponent rowData={row.original} />
                      </div>
                    ) : null
                  }
                >
                  {row.getVisibleCells().map((cell) => {
                    const { cellClassName, fromSize } =
                      cell.column.columnDef.meta ?? {};

                    return (
                      <Table.Cell
                        key={cell.id}
                        className={getColumnClassName(cellClassName, fromSize)}
                      >
                        {flexRender(
                          cell.column.columnDef.cell,
                          cell.getContext(),
                        )}
                      </Table.Cell>
                    );
                  })}
                </Table.Row>
              );
            })
          ) : (
            <Table.Row>
              <Table.Cell
                className={styles.noResultsCell}
                colSpan={totalColumnCount}
                align="center"
              >
                {noResultsText}
              </Table.Cell>
            </Table.Row>
          )}
        </Table.Body>
      </Table>

      {state?.pagination && (
        <Pagination
          totalPages={table.getPageCount()}
          currentPage={table.getState().pagination.pageIndex + 1}
          onChange={(newPageNumber) => table.setPageIndex(newPageNumber - 1)}
        />
      )}
    </Grid>
  );
};
