1 import {SerializedTableNode, TableNode, TableRowNode} from "@lexical/table";
2 import {DOMConversion, DOMConversionMap, DOMConversionOutput, LexicalEditor, LexicalNode, Spread} from "lexical";
3 import {EditorConfig} from "lexical/LexicalEditor";
4 import {el} from "../helpers";
// JSON shape of a CustomTableNode: SerializedTableNode extended (via Spread)
// with extra fields. NOTE(review): the field lines are not visible in this
// excerpt; exportJSON/importJSON in this file write/read `id` and `colWidths`.
6 export type SerializedCustomTableNode = Spread<{
9 }, SerializedTableNode>
// Lexical table node extending TableNode with an element id and a list of
// per-column width strings, which createDOM renders as <col> elements.
// NOTE(review): this excerpt omits many original lines (the embedded line
// numbers skip and no closing braces are visible); the comments below describe
// only the visible statements.
11 export class CustomTableNode extends TableNode {
// Per-column CSS width values (e.g. '120px'), indexed by column.
13 __colWidths: string[] = [];
// Node type name (enclosing method signature not visible here; presumably the
// static getType() that Lexical requires of custom nodes).
16 return 'custom-table';
// Writable-instance access; presumably the body of setId (signature not visible).
20 const self = this.getWritable();
// Latest-instance access; presumably the body of getId (signature not visible).
25 const self = this.getLatest();
// Replaces the stored widths with the given array. The array is stored by
// reference, not copied.
29 setColWidths(widths: string[]) {
30 const self = this.getWritable();
31 self.__colWidths = widths;
// Returns the stored array itself rather than a copy, so mutating the result
// mutates node state directly.
34 getColWidths(): string[] {
35 const self = this.getLatest();
36 return self.__colWidths;
// Copies key, id and column widths onto a new instance.
// NOTE(review): __colWidths is copied by reference, so the clone and the source
// node share a single array; an in-place edit through either (e.g. on the result
// of getColWidths()) affects both. Consider [...node.__colWidths].
39 static clone(node: CustomTableNode) {
40 const newNode = new CustomTableNode(node.__key);
41 newNode.__id = node.__id;
42 newNode.__colWidths = node.__colWidths;
// Builds the base table element via TableNode, sets its id attribute, and, when
// widths are stored, creates a <colgroup> holding one <col> per width with that
// inline width. (The lines that attach the <col>/<colgroup> and return `dom`
// are not visible in this excerpt.)
46 createDOM(config: EditorConfig): HTMLElement {
47 const dom = super.createDOM(config);
48 const id = this.getId();
50 dom.setAttribute('id', id);
53 const colWidths = this.getColWidths();
54 if (colWidths.length > 0) {
55 const colgroup = el('colgroup');
56 for (const width of colWidths) {
57 const col = el('col');
59 col.style.width = width;
// Lexical hook deciding whether the DOM element must be re-created (per Lexical
// docs, returning true re-creates it); the body is not visible here.
69 updateDOM(): boolean {
// Serializes the base TableNode fields plus colWidths (any other added fields
// are on lines not visible in this excerpt).
73 exportJSON(): SerializedCustomTableNode {
75 ...super.exportJSON(),
79 colWidths: this.__colWidths,
// Rebuilds a node from its JSON form, restoring the id and column widths.
83 static importJSON(serializedNode: SerializedCustomTableNode): CustomTableNode {
84 const node = $createCustomTableNode();
85 node.setId(serializedNode.id);
86 node.setColWidths(serializedNode.colWidths);
// DOM import: maps <table> elements to a CustomTableNode, copying the element's
// id and the column widths derived by getTableColumnWidths().
90 static importDOM(): DOMConversionMap|null {
92 table(node: HTMLElement): DOMConversion|null {
94 conversion: (element: HTMLElement): DOMConversionOutput|null => {
95 const node = $createCustomTableNode();
98 node.setId(element.id);
101 const colWidths = getTableColumnWidths(element as HTMLTableElement);
102 node.setColWidths(colWidths);
// Derives per-column width strings from a <table> element.
// The <colgroup> children are used when the colgroup's child count equals that
// of the widest row, or when the table has no rows. If that yields no non-empty
// width, the widest row's cells are used instead.
// (The return statement is not visible in this excerpt.)
113 function getTableColumnWidths(table: HTMLTableElement): string[] {
114 const maxColRow = getMaxColRowFromTable(table);
116 const colGroup = table.querySelector('colgroup');
117 let widths: string[] = [];
118 if (colGroup && (colGroup.childElementCount === maxColRow?.childElementCount || !maxColRow)) {
119 widths = extractWidthsFromRow(colGroup);
// Fallback: the colgroup was unusable or every width read from it was empty.
121 if (widths.filter(Boolean).length === 0 && maxColRow) {
122 widths = extractWidthsFromRow(maxColRow);
// Finds the <tr> with the most element children. The strict '>' comparison
// means the first such row wins on ties. The lines that record the matching row
// and return it are not visible in this excerpt.
// NOTE(review): querySelectorAll('tr') also matches rows of nested tables, and
// childElementCount does not account for colspan; confirm both are acceptable.
128 function getMaxColRowFromTable(table: HTMLTableElement): HTMLTableRowElement|null {
129 const rows = table.querySelectorAll('tr');
130 let maxColCount: number = 0;
131 let maxColRow: HTMLTableRowElement|null = null;
133 for (const row of rows) {
134 if (row.childElementCount > maxColCount) {
136 maxColCount = row.childElementCount;
// Returns one width string per direct child element (the cells of a <tr>, or the
// <col> entries of a <colgroup>), each read via extractWidthFromElement.
143 function extractWidthsFromRow(row: HTMLTableRowElement|HTMLTableColElement) {
144 return [...row.children].map(child => extractWidthFromElement(child as HTMLElement))
// Reads an element's width: the inline style width first, then the `width`
// attribute. A purely numeric value (e.g. '120') gets 'px' appended.
// The closing return is not visible in this excerpt.
// NOTE(review): Number('  ') is 0, not NaN, so a whitespace-only width attribute
// would pass the check below and become '  px'.
147 function extractWidthFromElement(element: HTMLElement): string {
148 let width = element.style.width || element.getAttribute('width');
149 if (width && !Number.isNaN(Number(width))) {
150 width = width + 'px';
// Factory for a new CustomTableNode (column widths start as an empty array).
156 export function $createCustomTableNode(): CustomTableNode {
157 return new CustomTableNode();
// Returns whether `node` is a CustomTableNode instance.
// NOTE(review): declaring the return type as the predicate
// `node is CustomTableNode` would let callers narrow without a cast.
160 export function $isCustomTableNode(node: LexicalNode | null | undefined): boolean {
161 return node instanceof CustomTableNode;
// Sets column `columnIndex` of `node` to `${width}px`.
// The widest row's child count is tracked in `maxCols` (its declaration and
// update lines are not visible here). If the stored widths array is empty or
// shorter than that count, it is replaced by an array of '' of that length.
// NOTE(review):
//  - Replacing a too-short array discards any widths that were already set.
//  - Otherwise, the array returned by getColWidths() is mutated in place before
//    setColWidths(). That array is the node's own stored array, which clone()
//    shares by reference. Copying it first ([...node.getColWidths()]) would
//    avoid touching other node versions.
//  - Whether execution stops after the console.error below is not visible here.
164 export function $setTableColumnWidth(node: CustomTableNode, columnIndex: number, width: number): void {
165 const rows = node.getChildren() as TableRowNode[];
167 for (const row of rows) {
168 const cellCount = row.getChildren().length;
169 if (cellCount > maxCols) {
174 let colWidths = node.getColWidths();
175 if (colWidths.length === 0 || colWidths.length < maxCols) {
176 colWidths = Array(maxCols).fill('');
// An out-of-range column index is reported via console.error.
179 if (columnIndex + 1 > colWidths.length) {
180 console.error(`Attempted to set table column width for column [${columnIndex}] but only ${colWidths.length} columns found`);
183 colWidths[columnIndex] = width + 'px';
184 node.setColWidths(colWidths);
187 export function $getTableColumnWidth(editor: LexicalEditor, node: CustomTableNode, columnIndex: number): number {
188 const colWidths = node.getColWidths();
189 if (colWidths.length > columnIndex && colWidths[columnIndex].endsWith('px')) {
190 return Number(colWidths[columnIndex].replace('px', ''));
193 // Otherwise, get from table element
194 const table = editor.getElementByKey(node.__key) as HTMLTableElement|null;
196 const maxColRow = getMaxColRowFromTable(table);
197 if (maxColRow && maxColRow.children.length > columnIndex) {
198 const cell = maxColRow.children[columnIndex];
199 return cell.clientWidth;