1 import {SerializedTableNode, TableNode, TableRowNode} from "@lexical/table";
2 import {DOMConversion, DOMConversionMap, DOMConversionOutput, LexicalEditor, LexicalNode, Spread} from "lexical";
3 import {EditorConfig} from "lexical/LexicalEditor";
5 import {el} from "../utils/dom";
/**
 * JSON shape for CustomTableNode: the parent SerializedTableNode plus the extra
 * fields declared in the object literal. importJSON reads `id` and `colWidths` from it.
 */
5 export type SerializedCustomTableNode = Spread<{
// NOTE(review): the field declarations are not visible in this excerpt; presumably
// `id: string` and `colWidths: string[]` (matching importJSON/exportJSON) — confirm.
10 }, SerializedTableNode>
/**
 * Lexical TableNode subclass that also carries an element id and a list of
 * per-column CSS widths. The widths are rendered as a <colgroup> in createDOM.
 */
12 export class CustomTableNode extends TableNode {
// One CSS width string per column (e.g. "120px"); an empty array means no <colgroup> is built.
14 __colWidths: string[] = [];
// Node type identifier (enclosing method header not visible here; presumably static getType()).
17 return 'custom-table';
// Fragment: writable-node access from a method whose header is not visible (presumably setId()).
21 const self = this.getWritable();
// Fragment: latest-node access from a method whose header is not visible (presumably getId()).
26 const self = this.getLatest();
/**
 * Replace the column width list on the writable version of this node.
 * The given array is stored by reference, not copied.
 */
30 setColWidths(widths: string[]) {
31 const self = this.getWritable();
32 self.__colWidths = widths;
/**
 * Return the column widths of the latest version of this node.
 * This is the stored array itself, so mutating it mutates node state directly.
 */
35 getColWidths(): string[] {
36 const self = this.getLatest();
37 return self.__colWidths;
// Lexical clone hook: copies key, id and column widths onto a new instance.
40 static clone(node: CustomTableNode) {
41 const newNode = new CustomTableNode(node.__key);
42 newNode.__id = node.__id;
// NOTE(review): the array is shared by reference with the source node rather than copied.
// Any in-place mutation (as done in $setTableColumnWidth) then also alters the prior node
// version. Consider `[...node.__colWidths]`.
43 newNode.__colWidths = node.__colWidths;
// Build the parent table DOM, then decorate it with the node id and column widths.
47 createDOM(config: EditorConfig): HTMLElement {
48 const dom = super.createDOM(config);
49 const id = this.getId();
// Expose the node id as the DOM element's id attribute.
51 dom.setAttribute('id', id);
54 const colWidths = this.getColWidths();
// Only when explicit widths exist: build a <colgroup> with one <col> per stored width
// (where the colgroup is attached is not visible in this excerpt).
55 if (colWidths.length > 0) {
56 const colgroup = el('colgroup');
57 for (const width of colWidths) {
58 const col = el('col');
60 col.style.width = width;
// Body not visible in this excerpt.
70 updateDOM(): boolean {
// Serialize: the parent's JSON plus this node's column widths (other added fields not visible).
74 exportJSON(): SerializedCustomTableNode {
76 ...super.exportJSON(),
80 colWidths: this.__colWidths,
// Rebuild a node from JSON, restoring its id and column widths.
84 static importJSON(serializedNode: SerializedCustomTableNode): CustomTableNode {
85 const node = $createCustomTableNode();
86 node.setId(serializedNode.id);
// NOTE(review): colWidths is passed through unchecked. If stored JSON ever lacks the field,
// undefined is stored and `getColWidths().length` in createDOM would throw; consider `?? []`.
87 node.setColWidths(serializedNode.colWidths);
// HTML import: map <table> elements to CustomTableNode.
91 static importDOM(): DOMConversionMap|null {
93 table(node: HTMLElement): DOMConversion|null {
95 conversion: (element: HTMLElement): DOMConversionOutput|null => {
96 const node = $createCustomTableNode();
// Carry over the element's id attribute.
99 node.setId(element.id);
// Derive column widths from the table's <colgroup> or its widest row.
102 const colWidths = getTableColumnWidths(element as HTMLTableElement);
103 node.setColWidths(colWidths);
/**
 * Derive per-column CSS widths for a table element.
 * Uses the <colgroup> when its child count matches the widest row (or no row was found).
 * Falls back to the widest row's cells when the colgroup yielded no non-empty widths.
 * Entries can be empty strings, which is why the truthiness filter is used.
 */
114 function getTableColumnWidths(table: HTMLTableElement): string[] {
115 const maxColRow = getMaxColRowFromTable(table);
// NOTE(review): querySelector also matches a <colgroup> inside a nested table.
117 const colGroup = table.querySelector('colgroup');
118 let widths: string[] = [];
// NOTE(review): childElementCount counts <col> elements, so a colgroup using `span`
// will not match the row's cell count and is ignored.
119 if (colGroup && (colGroup.childElementCount === maxColRow?.childElementCount || !maxColRow)) {
120 widths = extractWidthsFromRow(colGroup);
// No usable colgroup widths: read them from the widest row's cells instead.
122 if (widths.filter(Boolean).length === 0 && maxColRow) {
123 widths = extractWidthsFromRow(maxColRow);
/**
 * Find the <tr> with the most child elements. The comparison is a strict `>`,
 * so the first such row is presumably kept on ties (the assignment line is not visible).
 * The result stays null when no row has any child elements.
 */
129 function getMaxColRowFromTable(table: HTMLTableElement): HTMLTableRowElement|null {
// NOTE(review): querySelectorAll('tr') also returns rows of nested tables, and
// `table` must be non-null here or this call throws.
130 const rows = table.querySelectorAll('tr');
131 let maxColCount: number = 0;
132 let maxColRow: HTMLTableRowElement|null = null;
134 for (const row of rows) {
// Counts cell elements; colspan is not taken into account.
135 if (row.childElementCount > maxColCount) {
137 maxColCount = row.childElementCount;
/**
 * Map each child element of a row (its cells) or of a colgroup (its <col>s)
 * to a CSS width string via extractWidthFromElement, preserving order.
 */
144 function extractWidthsFromRow(row: HTMLTableRowElement|HTMLTableColElement) {
145 return [...row.children].map(child => extractWidthFromElement(child as HTMLElement))
/**
 * Read an element's width from its inline style, falling back to the `width` attribute.
 * Purely numeric values (e.g. width="120") get a "px" suffix; values such as "50%" are left as-is.
 */
148 function extractWidthFromElement(element: HTMLElement): string {
// NOTE(review): `width` is null when neither source is set. The return statement is not
// visible here; confirm it falls back to '' to honour the `string` return type.
149 let width = element.style.width || element.getAttribute('width');
150 if (width && !Number.isNaN(Number(width))) {
151 width = width + 'px';
/** Create a new CustomTableNode with default state (via the no-argument constructor). */
157 export function $createCustomTableNode(): CustomTableNode {
158 return new CustomTableNode();
/** Type guard: true when `node` is a CustomTableNode instance; false for null/undefined. */
161 export function $isCustomTableNode(node: LexicalNode | null | undefined): node is CustomTableNode {
162 return node instanceof CustomTableNode;
/**
 * Set one column's width on a table node.
 * Calls setColWidths (and therefore getWritable), so this must run inside an editor update.
 * @param node the table whose widths are changed
 * @param columnIndex zero-based column index
 * @param width width in pixels, stored as `${width}px`
 */
165 export function $setTableColumnWidth(node: CustomTableNode, columnIndex: number, width: number): void {
166 const rows = node.getChildren() as TableRowNode[];
// Column count = largest cell count across rows (the `maxCols` declaration is not visible here;
// colspan is not taken into account).
168 for (const row of rows) {
169 const cellCount = row.getChildren().length;
170 if (cellCount > maxCols) {
// NOTE(review): getColWidths() returns the stored array itself, which clone() shares with
// earlier node versions. The in-place write below therefore also mutates prior editor state.
// Consider copying first: `[...node.getColWidths()]`.
175 let colWidths = node.getColWidths();
// NOTE(review): when fewer widths are stored than there are columns, all existing widths
// are discarded (reset to '') rather than padded out to the column count.
176 if (colWidths.length === 0 || colWidths.length < maxCols) {
177 colWidths = Array(maxCols).fill('');
// Out-of-range column index is logged. Whether an early return follows is not visible;
// if not, the assignment below extends the array past the column count.
180 if (columnIndex + 1 > colWidths.length) {
181 console.error(`Attempted to set table column width for column [${columnIndex}] but only ${colWidths.length} columns found`);
184 colWidths[columnIndex] = width + 'px';
185 node.setColWidths(colWidths);
188 export function $getTableColumnWidth(editor: LexicalEditor, node: CustomTableNode, columnIndex: number): number {
189 const colWidths = node.getColWidths();
190 if (colWidths.length > columnIndex && colWidths[columnIndex].endsWith('px')) {
191 return Number(colWidths[columnIndex].replace('px', ''));
194 // Otherwise, get from table element
195 const table = editor.getElementByKey(node.__key) as HTMLTableElement|null;
197 const maxColRow = getMaxColRowFromTable(table);
198 if (maxColRow && maxColRow.children.length > columnIndex) {
199 const cell = maxColRow.children[columnIndex];
200 return cell.clientWidth;