import {BaseSelection, LexicalEditor} from "lexical";
import {
    $isTableCellNode,
    $isTableNode,
    $isTableRowNode,
    $isTableSelection, TableCellNode, TableNode,
    TableRowNode,
    TableSelection,
} from "@lexical/table";
import {$getParentOfType} from "./nodes";
import {$getNodeFromSelection} from "./selection";
import {el, formatSizeValue} from "./dom";
import {TableMap} from "./table-map";
/**
 * Look up the table node that contains the given cell.
 *
 * @param cell - The table cell to start from.
 * @returns The ancestor found by `$getParentOfType` using `$isTableNode`, or null.
 */
function $getTableFromCell(cell: TableCellNode): TableNode|null {
    const owningTable = $getParentOfType(cell, $isTableNode);
    return owningTable as TableNode|null;
}
/**
 * Read the column widths of a rendered/parsed DOM table.
 *
 * The table's `<colgroup>` is used when its column count matches the widest row
 * (or when the table has no widest row at all). If that produces no non-empty
 * width, the widths are read from the cells of the widest row instead.
 *
 * @param table - The DOM table to inspect.
 * @returns One width string per column ('' where no width is set).
 */
export function getTableColumnWidths(table: HTMLTableElement): string[] {
    const widestRow = getMaxColRowFromTable(table);
    const colGroup = table.querySelector('colgroup');

    let widths: string[] = [];
    if (colGroup && (!widestRow || colGroup.childElementCount === widestRow.childElementCount)) {
        widths = extractWidthsFromRow(colGroup);
    }

    const hasAnyWidth = widths.some(Boolean);
    if (!hasAnyWidth && widestRow) {
        widths = extractWidthsFromRow(widestRow);
    }

    return widths;
}
/**
 * Find the `<tr>` in the table with the most child elements.
 *
 * Ties keep the first such row encountered. Rows with zero children never win.
 *
 * @param table - The DOM table to scan (all descendant `tr` elements are considered).
 * @returns The widest row, or null if no row has any children.
 */
function getMaxColRowFromTable(table: HTMLTableElement): HTMLTableRowElement | null {
    let widestRow: HTMLTableRowElement | null = null;
    let widestCount = 0;

    table.querySelectorAll('tr').forEach((row) => {
        if (row.childElementCount > widestCount) {
            widestRow = row;
            widestCount = row.childElementCount;
        }
    });

    return widestRow;
}
/**
 * Read a width value from every direct child of a row or colgroup element.
 *
 * @param row - A `<tr>` or `<colgroup>` element.
 * @returns The widths in child order, as produced by `extractWidthFromElement`.
 */
function extractWidthsFromRow(row: HTMLTableRowElement | HTMLTableColElement) {
    return Array.from(row.children, (child) => extractWidthFromElement(child as HTMLElement));
}
/**
 * Read an element's width, preferring the inline style over the `width` attribute.
 *
 * A bare numeric value (e.g. the legacy `width="120"` attribute) gets a `px` suffix.
 * Any other value (such as `30%`) is returned as-is.
 *
 * @param element - The cell or `col` element to read.
 * @returns The width string, or '' when neither source provides one.
 */
function extractWidthFromElement(element: HTMLElement): string {
    const rawWidth = element.style.width || element.getAttribute('width') || '';
    const isBareNumber = rawWidth !== '' && !Number.isNaN(Number(rawWidth));
    return isBareNumber ? `${rawWidth}px` : rawWidth;
}
/**
 * Set the width of one column on a table node.
 *
 * The stored column-width list is reset to `maxCols` empty entries when it is
 * empty or shorter than the row with the most children. The requested column's
 * entry is then set to `formatSizeValue(width)` and written back via `setColWidths`.
 *
 * @param node - The table node to update.
 * @param columnIndex - Zero-based column index; negative or non-integer values are rejected.
 * @param width - The width to store, formatted through `formatSizeValue`.
 */
export function $setTableColumnWidth(node: TableNode, columnIndex: number, width: number|string): void {
    if (!Number.isInteger(columnIndex) || columnIndex < 0) {
        // Writing colWidths[-1] / colWidths[NaN] would only add a junk property to the array.
        console.error(`Attempted to set table column width for invalid column index [${columnIndex}]`);
        return;
    }

    const rows = node.getChildren() as TableRowNode[];
    let maxCols = 0;
    for (const row of rows) {
        const cellCount = row.getChildren().length;
        if (cellCount > maxCols) {
            maxCols = cellCount;
        }
    }

    // Copy so the array held by the node is never mutated in place; the change is
    // applied only through setColWidths below.
    let colWidths: string[] = [...node.getColWidths()];
    if (colWidths.length === 0 || colWidths.length < maxCols) {
        colWidths = Array(maxCols).fill('');
    }

    if (columnIndex + 1 > colWidths.length) {
        // Can happen when colspans make the true column count exceed the child count.
        console.error(`Attempted to set table column width for column [${columnIndex}] but only ${colWidths.length} columns found`);
        // Pad with empty widths so the stored list never contains undefined holes.
        while (colWidths.length <= columnIndex) {
            colWidths.push('');
        }
    }

    colWidths[columnIndex] = formatSizeValue(width);
    node.setColWidths(colWidths);
}
/**
 * Get the width of a table column in pixels.
 *
 * Uses the node's stored column width when it is a parseable `px` value. Otherwise,
 * it measures (`clientWidth`) the cell at that index in the widest row of the
 * table's rendered DOM element.
 *
 * @param editor - Editor used to resolve the table node to its DOM element.
 * @param node - The table node.
 * @param columnIndex - Zero-based column index.
 * @returns The width in pixels, or 0 when it cannot be determined.
 */
export function $getTableColumnWidth(editor: LexicalEditor, node: TableNode, columnIndex: number): number {
    const colWidths = node.getColWidths();
    // Out-of-range (including negative) indexes and holes yield undefined here rather than throwing.
    const storedWidth = colWidths[columnIndex];
    if (typeof storedWidth === 'string' && storedWidth.endsWith('px')) {
        const parsedWidth = Number(storedWidth.replace('px', ''));
        if (!Number.isNaN(parsedWidth)) {
            return parsedWidth;
        }
    }

    // Otherwise, get from table element
    const table = editor.getElementByKey(node.getKey()) as HTMLTableElement | null;
    if (table && columnIndex >= 0) {
        const maxColRow = getMaxColRowFromTable(table);
        const cell = maxColRow ? maxColRow.children.item(columnIndex) : null;
        if (cell) {
            return cell.clientWidth;
        }
    }

    return 0;
}
/**
 * Compute the grid column index of a cell within its own row.
 *
 * The colspans (minimum 1) of the cell and every cell before it are summed, so for a
 * spanning cell the result is the index of the last column it covers. Rowspans
 * from earlier rows are not taken into account.
 *
 * @param node - The cell to locate.
 * @returns The zero-based column index, or -1 if the cell's parent is not a table row.
 */
function $getCellColumnIndex(node: TableCellNode): number {
    const parentRow = node.getParent();
    if (!$isTableRowNode(parentRow)) {
        return -1;
    }

    const targetKey = node.getKey();
    let columnsConsumed = 0;
    for (const sibling of parentRow.getChildren<TableCellNode>()) {
        columnsConsumed += sibling.getColSpan() || 1;
        if (sibling.getKey() === targetKey) {
            break;
        }
    }

    return columnsConsumed - 1;
}
/**
 * Set the width of the table column the given cell belongs to.
 *
 * No-op when the cell is not inside a table or its column index cannot be resolved.
 *
 * @param cell - A cell within the column to size.
 * @param width - The width value to store for that column.
 */
export function $setTableCellColumnWidth(cell: TableCellNode, width: string): void {
    const table = $getTableFromCell(cell);
    const index = $getCellColumnIndex(cell);
    if (!table || index < 0) {
        return;
    }

    $setTableColumnWidth(table, index, width);
}
/**
 * Get the stored width of the table column the given cell belongs to.
 *
 * @param editor - Unused; kept for interface compatibility.
 * @param cell - A cell within the column to read.
 * @returns The stored column width, or '' when the cell has no table, its column
 *          index cannot be resolved, or no width is stored for that column.
 */
export function $getTableCellColumnWidth(editor: LexicalEditor, cell: TableCellNode): string {
    const table = $getTableFromCell(cell);
    const index = $getCellColumnIndex(cell);
    // index is -1 when the cell is not in a row. Without this guard, `widths.length > -1`
    // passes and widths[-1] (undefined) would be returned as a "string".
    if (!table || index < 0) {
        return '';
    }

    const widths = table.getColWidths();
    return widths[index] ?? '';
}
/**
 * Build a `<colgroup>` element containing one `<col>` per given width.
 *
 * Empty width strings produce a `<col>` with no inline width.
 *
 * @param colWidths - Column width values, in column order.
 * @returns The colgroup element, or null when no widths are given.
 */
export function buildColgroupFromTableWidths(colWidths: string[]): HTMLElement|null {
    if (colWidths.length === 0) {
        return null;
    }

    const colgroup = el('colgroup');
    const cols = colWidths.map((colWidth) => {
        const col = el('col');
        if (colWidth) {
            col.style.width = colWidth;
        }
        return col;
    });
    colgroup.append(...cols);

    return colgroup;
}
/**
 * Get the table cells targeted by a selection.
 *
 * For a table selection, this is every selected node that is a table cell. For any
 * other selection, it is the single cell found by `$getNodeFromSelection`, if any.
 *
 * @param selection - The current selection, or null.
 * @returns The selected cells (possibly empty).
 */
export function $getTableCellsFromSelection(selection: BaseSelection|null): TableCellNode[] {
    if (!$isTableSelection(selection)) {
        const singleCell = $getNodeFromSelection(selection, $isTableCellNode) as TableCellNode;
        return singleCell ? [singleCell] : [];
    }

    return selection.getNodes().filter((node): node is TableCellNode => $isTableCellNode(node));
}
/**
 * Merge the cells covered by a table selection into the first cell of the range.
 *
 * The selection rectangle's (toX, toY) corner is extended by the colspan/rowspan of
 * the cell at that position. The cells TableMap reports within that range are then
 * collected, the children of every cell after the first are appended to the first,
 * and the first cell is given the combined column/row span.
 *
 * Returns early without changes if there are no selected cells, no containing
 * table, no cell at the head corner, or no cells in the computed range.
 */
export function $mergeTableCellsInSelection(selection: TableSelection): void {
    const selectionShape = selection.getShape();
    const cells = $getTableCellsFromSelection(selection);
    if (cells.length === 0) {
        return;
    }

    const table = $getTableFromCell(cells[0]);
    if (!table) {
        return;
    }

    const tableMap = new TableMap(table);
    const headCell = tableMap.getCellAtPosition(selectionShape.toX, selectionShape.toY);
    if (!headCell) {
        return;
    }

    // We have to adjust the shape since it won't take into account spans for the head corner position.
    // NOTE(review): this adds the head cell's whole span to toX/toY, which assumes the cell
    // starts at (toX, toY). If a spanning cell only covers that position from further up/left,
    // the range may be over-extended — confirm against TableMap.getCellAtPosition semantics.
    const fixedToX = selectionShape.toX + ((headCell.getColSpan() || 1) - 1);
    const fixedToY = selectionShape.toY + ((headCell.getRowSpan() || 1) - 1);

    const mergeCells = tableMap.getCellsInRange({
        fromX: selectionShape.fromX,
        fromY: selectionShape.fromY,
        toX: fixedToX,
        toY: fixedToY,
    });

    if (mergeCells.length === 0) {
        return;
    }

    // Presumably getCellsInRange returns the top-left cell first — TODO confirm; it becomes the merge target.
    const firstCell = mergeCells[0];
    const newWidth = Math.abs(selectionShape.fromX - fixedToX) + 1;
    const newHeight = Math.abs(selectionShape.fromY - fixedToY) + 1;

    for (let i = 1; i < mergeCells.length; i++) {
        const mergeCell = mergeCells[i];
        // Move content into the target cell, then drop the emptied cell.
        firstCell.append(...mergeCell.getChildren());
        mergeCell.remove();
    }

    firstCell.setColSpan(newWidth);
    firstCell.setRowSpan(newHeight);
}
/**
 * Get the distinct table rows that contain the selected cells.
 *
 * @param selection - The current selection, or null.
 * @returns Each parent row once, in the order its first cell appears in the selection.
 */
export function $getTableRowsFromSelection(selection: BaseSelection|null): TableRowNode[] {
    const cells = $getTableCellsFromSelection(selection);
    // Use a Map rather than a plain object. Lexical node keys are numeric strings (per
    // Lexical's key generator — confirm for this fork). Integer-like object keys enumerate
    // in ascending numeric order, not insertion order, so Object.values() would return
    // rows by creation key instead of selection/document order.
    const rowsByKey = new Map<string, TableRowNode>();
    for (const cell of cells) {
        const row = cell.getParent();
        if ($isTableRowNode(row)) {
            rowsByKey.set(row.getKey(), row);
        }
    }

    return [...rowsByKey.values()];
}
/**
 * Get the table containing the selection, based on the first selected cell.
 *
 * @param selection - The current selection, or null.
 * @returns The table node, or null if no cell is selected or no table ancestor is found.
 */
export function $getTableFromSelection(selection: BaseSelection|null): TableNode|null {
    const [firstCell] = $getTableCellsFromSelection(selection);
    if (!firstCell) {
        return null;
    }

    const ancestor = $getParentOfType(firstCell, $isTableNode);
    return $isTableNode(ancestor) ? ancestor : null;
}
244 export function $clearTableSizes(table: TableNode): void {
245 table.setColWidths([]);
247 // TODO - Extra form things once table properties and extra things
250 for (const row of table.getChildren()) {
251 if (!$isTableRowNode(row)) {
255 const rowStyles = row.getStyles();
256 rowStyles.delete('height');
257 rowStyles.delete('width');
258 row.setStyles(rowStyles);
260 const cells = row.getChildren().filter(c => $isTableCellNode(c));
261 for (const cell of cells) {
262 const cellStyles = cell.getStyles();
263 cellStyles.delete('height');
264 cellStyles.delete('width');
265 cell.setStyles(cellStyles);
271 export function $clearTableFormatting(table: TableNode): void {
272 table.setColWidths([]);
273 table.setStyles(new Map);
275 for (const row of table.getChildren()) {
276 if (!$isTableRowNode(row)) {
280 row.setStyles(new Map);
282 const cells = row.getChildren().filter(c => $isTableCellNode(c));
283 for (const cell of cells) {
284 cell.setStyles(new Map);
/**
 * Run a callback for every cell in the table, row by row.
 *
 * Children that are not table rows, and row children that are not table cells,
 * are skipped. Returning `false` from the callback stops iteration immediately.
 *
 * @param table - The table whose cells to visit.
 * @param callback - Invoked once per cell; return `false` to stop early.
 */
export function $forEachTableCell(table: TableNode, callback: (c: TableCellNode) => void|false): void {
    for (const row of table.getChildren()) {
        if (!$isTableRowNode(row)) {
            continue;
        }

        for (const cell of row.getChildren()) {
            if ($isTableCellNode(cell) && callback(cell) === false) {
                return;
            }
        }
    }
}
/**
 * Get the `padding` style shared by every cell in the table.
 *
 * @param table - The table to inspect.
 * @returns The common padding value, or '' if cells differ or none set a padding.
 */
export function $getCellPaddingForTable(table: TableNode): string {
    let sharedPadding: string|null = null;
    let isMixed = false;

    $forEachTableCell(table, (cell: TableCellNode) => {
        const cellPadding = cell.getStyles().get('padding') || '';
        if (sharedPadding === null) {
            sharedPadding = cellPadding;
        } else if (cellPadding !== sharedPadding) {
            // A differing cell means there is no single table-wide value; stop scanning.
            isMixed = true;
            return false;
        }
    });

    return isMixed ? '' : (sharedPadding || '');
}