import { findParentNode } from '@tiptap/core';
import { CellSelection, TableMap } from '@tiptap/pm/tables';

/**
 * Build a predicate reporting whether every cell inside `rect` is covered by
 * the given cell selection.
 *
 * @param {{left: number, right: number, top: number, bottom: number}} rect -
 *   Rectangle in table-map coordinates, in the form accepted by
 *   `TableMap.cellsInRect` (right/bottom are exclusive, e.g. one column is
 *   `left: i, right: i + 1`).
 * @returns {(selection: CellSelection) => boolean} Predicate. The selection
 *   must expose `$anchorCell` and `$headCell`, i.e. be a `CellSelection`.
 */
export const isRectSelected = rect => selection => {
  // `$anchorCell.node(-1)` / `start(-1)` address the table that holds the cell.
  const map = TableMap.get(selection.$anchorCell.node(-1));
  const start = selection.$anchorCell.start(-1);
  const cells = map.cellsInRect(rect);
  // Table-relative positions covered by the current selection. A Set keeps
  // the containment check linear instead of rescanning an array per cell.
  const selectedCells = new Set(
    map.cellsInRect(
      map.rectBetween(selection.$anchorCell.pos - start, selection.$headCell.pos - start),
    ),
  );

  return cells.every(cell => selectedCells.has(cell));
};

/**
 * Locate the closest ancestor of the selection whose schema spec declares
 * `tableRole: 'table'`.
 *
 * @param {Selection} selection - Editor selection to search upward from.
 * @returns {*} Whatever `findParentNode` yields for the first matching
 *   ancestor, or its "not found" value when no table encloses the selection.
 */
export const findTable = selection => {
  const findTableAncestor = findParentNode(({ type }) => type.spec.tableRole === 'table');
  return findTableAncestor(selection);
};

/**
 * Report whether `selection` is a table `CellSelection`.
 *
 * @param {*} selection - Any selection object.
 * @returns {boolean} `true` only for `CellSelection` instances.
 */
export const isCellSelection = selection => {
  const isCells = selection instanceof CellSelection;
  return isCells;
};

/**
 * Build a predicate that checks whether a whole table column is covered by
 * the selection.
 *
 * @param {number} columnIndex - Zero-based column index in the table map.
 * @returns {(selection: Selection) => boolean} Always `false` for selections
 *   that are not a `CellSelection`.
 */
export const isColumnSelected = columnIndex => selection => {
  if (!isCellSelection(selection)) return false;

  const { height } = TableMap.get(selection.$anchorCell.node(-1));
  // The full-height rectangle spanning only `columnIndex`.
  const columnRect = {
    left: columnIndex,
    right: columnIndex + 1,
    top: 0,
    bottom: height,
  };
  return isRectSelected(columnRect)(selection);
};

/**
 * Build a predicate that checks whether a whole table row is covered by the
 * selection.
 *
 * @param {number} rowIndex - Zero-based row index in the table map.
 * @returns {(selection: Selection) => boolean} Always `false` for selections
 *   that are not a `CellSelection`.
 */
export const isRowSelected = rowIndex => selection => {
  if (!isCellSelection(selection)) return false;

  const { width } = TableMap.get(selection.$anchorCell.node(-1));
  // The full-width rectangle spanning only `rowIndex`.
  const rowRect = {
    left: 0,
    right: width,
    top: rowIndex,
    bottom: rowIndex + 1,
  };
  return isRectSelected(rowRect)(selection);
};

/**
 * Check whether the selection covers every cell of its table.
 *
 * @param {Selection} selection - Selection to test.
 * @returns {boolean} Always `false` for selections that are not a
 *   `CellSelection`.
 */
export const isTableSelected = selection => {
  if (!isCellSelection(selection)) return false;

  const { width, height } = TableMap.get(selection.$anchorCell.node(-1));
  const wholeTable = {
    left: 0,
    right: width,
    top: 0,
    bottom: height,
  };
  return isRectSelected(wholeTable)(selection);
};

/**
 * Collect the cells that `TableMap.cellsInRect` reports for one or more
 * columns of the table enclosing the selection.
 *
 * @param {number | number[]} columnIndex - A column index or a list of them.
 *   Indexes outside `[0, width - 1]` contribute no cells.
 * @returns {(selection: Selection) => Array<{pos: number, start: number, node: *}> | null}
 *   `null` when the selection is not inside a table. Otherwise one entry per
 *   cell: `pos` is the absolute cell position, `start` is `pos + 1` (the start
 *   of the cell's content), and `node` is the cell node.
 */
export const getCellsInColumn = columnIndex => selection => {
  const table = findTable(selection);
  if (!table) return null;

  const { node: tableNode, start: tableStart } = table;
  const map = TableMap.get(tableNode);
  const columns = Array.isArray(columnIndex) ? columnIndex : [columnIndex];
  const found = [];

  columns.forEach(column => {
    if (column < 0 || column > map.width - 1) return;

    const offsets = map.cellsInRect({
      left: column,
      right: column + 1,
      top: 0,
      bottom: map.height,
    });
    // Offsets are relative to the table's content start.
    offsets.forEach(offset => {
      found.push({
        pos: tableStart + offset,
        start: tableStart + offset + 1,
        node: tableNode.nodeAt(offset),
      });
    });
  });

  return found;
};

/**
 * Collect the cells that `TableMap.cellsInRect` reports for one or more rows
 * of the table enclosing the selection.
 *
 * @param {number | number[]} rowIndex - A row index or a list of them.
 *   Indexes outside `[0, height - 1]` contribute no cells.
 * @returns {(selection: Selection) => Array<{pos: number, start: number, node: *}> | null}
 *   `null` when the selection is not inside a table. Otherwise one entry per
 *   cell: `pos` is the absolute cell position, `start` is `pos + 1` (the start
 *   of the cell's content), and `node` is the cell node.
 */
export const getCellsInRow = rowIndex => selection => {
  const table = findTable(selection);
  if (!table) return null;

  const { node: tableNode, start: tableStart } = table;
  const map = TableMap.get(tableNode);
  const rows = Array.isArray(rowIndex) ? rowIndex : [rowIndex];
  const found = [];

  rows.forEach(row => {
    if (row < 0 || row > map.height - 1) return;

    const offsets = map.cellsInRect({
      left: 0,
      right: map.width,
      top: row,
      bottom: row + 1,
    });
    // Offsets are relative to the table's content start.
    offsets.forEach(offset => {
      found.push({
        pos: tableStart + offset,
        start: tableStart + offset + 1,
        node: tableNode.nodeAt(offset),
      });
    });
  });

  return found;
};

/**
 * Collect every cell of the table enclosing the selection.
 *
 * @param {Selection} selection - Selection used to find the table.
 * @returns {Array<{pos: number, start: number, node: *}> | null} `null` when
 *   the selection is not inside a table. Otherwise one entry per cell: `pos`
 *   is the absolute cell position, `start` is `pos + 1` (the start of the
 *   cell's content), and `node` is the cell node.
 */
export const getCellsInTable = selection => {
  const table = findTable(selection);
  if (!table) return null;

  const { node: tableNode, start: tableStart } = table;
  const map = TableMap.get(tableNode);
  const { width, height } = map;
  const wholeTable = { left: 0, right: width, top: 0, bottom: height };

  // Offsets from the table map are relative to the table's content start.
  const toCellInfo = offset => ({
    pos: tableStart + offset,
    start: tableStart + offset + 1,
    node: tableNode.nodeAt(offset),
  });

  return map.cellsInRect(wholeTable).map(offset => toCellInfo(offset));
};

/**
 * Walk up from `$pos` and return the nearest ancestor node that satisfies
 * `predicate`.
 *
 * Only depths `$pos.depth` down to 1 are examined. The top-level node at
 * depth 0 is never tested.
 *
 * @param {ResolvedPos} $pos - Resolved document position to start from.
 * @param {(node: Node) => *} predicate - Truthy result marks a match.
 * @returns {{pos: number, start: number, depth: number, node: Node} | null}
 *   `pos` is `$pos.before(depth)`, `start` is `$pos.start(depth)`. Returns
 *   `null` when nothing matches.
 */
export const findParentNodeClosestToPos = ($pos, predicate) => {
  let depth = $pos.depth;

  while (depth > 0) {
    const candidate = $pos.node(depth);
    if (predicate(candidate)) {
      // depth >= 1 here, so `before(depth)` is always defined.
      return {
        pos: $pos.before(depth),
        start: $pos.start(depth),
        depth,
        node: candidate,
      };
    }
    depth -= 1;
  }

  return null;
};

/**
 * Find the nearest ancestor of `$pos` whose `tableRole` contains "cell"
 * (case-insensitive). This matches roles such as `cell` and `header_cell`.
 *
 * @param {ResolvedPos} $pos - Resolved document position to start from.
 * @returns {*} Result of `findParentNodeClosestToPos` for that predicate.
 */
export const findCellClosestToPos = $pos => {
  const isCellNode = ({ type }) => {
    const role = type.spec.tableRole;
    return role && /cell/i.test(role);
  };
  return findParentNodeClosestToPos($pos, isCellNode);
};

/**
 * Create a curried transaction transformer that selects a whole row or column
 * of the table enclosing the current selection.
 *
 * @param {'row' | 'column'} type - `'row'` selects a row. Any other value
 *   selects a column.
 * @returns {(index: number) => (tr: Transaction) => Transaction} The
 *   transaction is returned unchanged when the selection is not inside a
 *   table or `index` is out of range. Otherwise its selection is replaced by
 *   a `CellSelection` from the top-left to the bottom-right slot of the
 *   row/column. Merged cells touching those corners are included whole.
 */
const select = type => index => tr => {
  const table = findTable(tr.selection);
  if (!table) return tr;

  const map = TableMap.get(table.node);
  const isRowSelection = type === 'row';

  if (index < 0 || index >= (isRowSelection ? map.height : map.width)) return tr;

  const left = isRowSelection ? 0 : index;
  const top = isRowSelection ? index : 0;
  const right = isRowSelection ? map.width : index + 1;
  const bottom = isRowSelection ? index + 1 : map.height;

  // Read the corner cells straight from the table map. `cellsInRect` skips
  // cells that only overlap the rect (a colspan/rowspan cell starting before
  // `left`/`top`). It can therefore come back empty, which made the position
  // NaN, or drop the spanning cell. `map.map` holds the cell covering each slot.
  const head = table.start + map.map[top * map.width + left];
  const anchor = table.start + map.map[(bottom - 1) * map.width + (right - 1)];
  const $head = tr.doc.resolve(head);
  const $anchor = tr.doc.resolve(anchor);

  return tr.setSelection(new CellSelection($anchor, $head));
};

export const selectColumn = select('column');
export const selectRow = select('row');

/**
 * Replace the transaction's selection with a `CellSelection` spanning the
 * whole table that encloses the current selection.
 *
 * @param {Transaction} tr - Transaction to update.
 * @returns {Transaction} `tr` unchanged when there is no enclosing table or
 *   its map is empty. Otherwise the result of `tr.setSelection(...)`, anchored
 *   at the last slot of the table map with its head at the first slot.
 */
export const selectTable = tr => {
  const table = findTable(tr.selection);
  if (!table) return tr;

  // Table-relative cell positions, one per slot of the table map.
  const slots = TableMap.get(table.node).map;
  if (!slots || !slots.length) return tr;

  const firstCell = tr.doc.resolve(table.start + slots[0]);
  const lastCell = tr.doc.resolve(table.start + slots[slots.length - 1]);

  return tr.setSelection(new CellSelection(lastCell, firstCell));
};
