Skip to content

Commit daef4f0

Browse files
feat: Allow join-free alignment of analytic expressions (#1168)
* feat: Allow join-free alignment of analytic expressions
* address PR comments
* fix bugs in pull_up_selection
* fix unit test and remove validations
* fix test failures
1 parent 1c8d510 commit daef4f0

File tree

13 files changed

+747
-341
lines changed

13 files changed

+747
-341
lines changed

bigframes/core/__init__.py

Lines changed: 55 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import pyarrow as pa
2727
import pyarrow.feather as pa_feather
2828

29-
import bigframes.core.compile
3029
import bigframes.core.expression as ex
3130
import bigframes.core.guid
3231
import bigframes.core.identifiers as ids
@@ -35,15 +34,13 @@
3534
import bigframes.core.nodes as nodes
3635
from bigframes.core.ordering import OrderingExpression
3736
import bigframes.core.ordering as orderings
38-
import bigframes.core.rewrite
3937
import bigframes.core.schema as schemata
4038
import bigframes.core.tree_properties
4139
import bigframes.core.utils
4240
from bigframes.core.window_spec import WindowSpec
4341
import bigframes.dtypes
4442
import bigframes.operations as ops
4543
import bigframes.operations.aggregations as agg_ops
46-
import bigframes.session._io.bigquery
4744

4845
if typing.TYPE_CHECKING:
4946
from bigframes.session import Session
@@ -199,6 +196,8 @@ def as_cached(
199196

200197
def _try_evaluate_local(self):
    """Evaluate this expression's node locally; for unit-test paths only.

    Not fully featured. Raises an exception if the evaluation fails.
    """
    # Imported at call time rather than at module load.
    # NOTE(review): presumably this avoids an import cycle with
    # bigframes.core.compile - confirm.
    from bigframes.core.compile import test_only_try_evaluate

    return test_only_try_evaluate(self.node)
203202

204203
def get_column_type(self, key: str) -> bigframes.dtypes.Dtype:
@@ -422,22 +421,7 @@ def relational_join(
422421
l_mapping = { # Identity mapping, only rename right side
423422
lcol.name: lcol.name for lcol in self.node.ids
424423
}
425-
r_mapping = { # Rename conflicting names
426-
rcol.name: rcol.name
427-
if (rcol.name not in l_mapping)
428-
else bigframes.core.guid.generate_guid()
429-
for rcol in other.node.ids
430-
}
431-
other_node = other.node
432-
if set(other_node.ids) & set(self.node.ids):
433-
other_node = nodes.SelectionNode(
434-
other_node,
435-
tuple(
436-
(ex.deref(old_id), ids.ColumnId(new_id))
437-
for old_id, new_id in r_mapping.items()
438-
),
439-
)
440-
424+
other_node, r_mapping = self.prepare_join_names(other)
441425
join_node = nodes.JoinNode(
442426
left_child=self.node,
443427
right_child=other_node,
@@ -449,14 +433,63 @@ def relational_join(
449433
)
450434
return ArrayValue(join_node), (l_mapping, r_mapping)
451435

452-
def try_align_as_projection(
436+
def try_row_join(
    self,
    other: ArrayValue,
    conditions: typing.Tuple[typing.Tuple[str, str], ...] = (),
) -> Optional[
    typing.Tuple[ArrayValue, typing.Tuple[dict[str, str], dict[str, str]]]
]:
    """Try to combine ``self`` and ``other`` row-wise without a relational join.

    Delegates to ``bigframes.core.rewrite.try_join_as_projection``.

    Args:
        other: The right-hand expression.
        conditions: Pairs of column ids forwarded unchanged to the rewrite.

    Returns:
        ``None`` when the rewrite returns ``None``. Otherwise a tuple of the
        combined ``ArrayValue`` and ``(left_names, right_names)``, two dicts
        keyed by each side's input column ids. The left dict is an identity
        mapping; the right dict comes from ``prepare_join_names``.
    """
    # Only the right side is ever renamed, so the left mapping is identity.
    right_node, right_names = self.prepare_join_names(other)
    left_names = {column.name: column.name for column in self.node.ids}

    # Imported at call time rather than at module load.
    import bigframes.core.rewrite

    # NOTE(review): `conditions` is passed through as given, while
    # `right_node` may carry renamed ids (see `right_names`). Confirm the
    # rewrite expects pre-rename right ids here.
    projected = bigframes.core.rewrite.try_join_as_projection(
        self.node, right_node, conditions
    )
    if projected is not None:
        return ArrayValue(projected), (left_names, right_names)
    return None
459+
460+
def prepare_join_names(
    self, other: ArrayValue
) -> Tuple[bigframes.core.nodes.BigFrameNode, dict[str, str]]:
    """Prepare ``other``'s node so its column ids do not clash with ``self``'s.

    Args:
        other: The right-hand expression of a prospective join.

    Returns:
        A ``(node, mapping)`` pair. ``mapping`` goes from each of ``other``'s
        column ids to the id it has in ``node``. If the two sides share no
        ids, ``node`` is ``other.node`` itself and the mapping is identity.
        Otherwise ``node`` is a ``SelectionNode`` over ``other.node`` in which
        every clashing name is replaced by a freshly generated guid.
    """
    # Nothing overlaps: hand back the node untouched with an identity mapping.
    if set(other.node.ids).isdisjoint(self.node.ids):
        return other.node, {col_id: col_id for col_id in other.column_ids}

    # Rename conflicting names; names that do not clash keep their id.
    taken = self.column_ids
    renames: dict[str, str] = {}
    for column in other.node.ids:
        name = column.name
        renames[name] = (
            bigframes.core.guid.generate_guid() if name in taken else name
        )

    selection = tuple(
        (ex.deref(source), ids.ColumnId(target))
        for source, target in renames.items()
    )
    return nodes.SelectionNode(other.node, selection), renames
482+
483+
def try_legacy_row_join(
453484
self,
454485
other: ArrayValue,
455486
join_type: join_def.JoinType,
456487
join_keys: typing.Tuple[join_def.CoalescedColumnMapping, ...],
457488
mappings: typing.Tuple[join_def.JoinColumnMapping, ...],
458489
) -> typing.Optional[ArrayValue]:
459-
result = bigframes.core.rewrite.join_as_projection(
490+
import bigframes.core.rewrite
491+
492+
result = bigframes.core.rewrite.legacy_join_as_projection(
460493
self.node, other.node, join_keys, mappings, join_type
461494
)
462495
if result is not None:
@@ -488,11 +521,4 @@ def _gen_namespaced_uid(self) -> str:
488521
return self._gen_namespaced_uids(1)[0]
489522

490523
def _gen_namespaced_uids(self, n: int) -> List[str]:
491-
i = len(self.node.defined_variables)
492-
genned_ids: List[str] = []
493-
while len(genned_ids) < n:
494-
attempted_id = f"col_{i}"
495-
if attempted_id not in self.node.defined_variables:
496-
genned_ids.append(attempted_id)
497-
i = i + 1
498-
return genned_ids
524+
return [ids.ColumnId.unique().name for _ in range(n)]

bigframes/core/blocks.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2341,7 +2341,9 @@ def join(
23412341
# Handle null index, which only supports row join
23422342
# This is the canonical way of aligning on null index, so always allow (ignore block_identity_join)
23432343
if self.index.nlevels == other.index.nlevels == 0:
2344-
result = try_row_join(self, other, how=how)
2344+
result = try_legacy_row_join(self, other, how=how) or try_new_row_join(
2345+
self, other
2346+
)
23452347
if result is not None:
23462348
return result
23472349
raise bigframes.exceptions.NullIndexError(
@@ -2354,7 +2356,9 @@ def join(
23542356
and (self.index.nlevels == other.index.nlevels)
23552357
and (self.index.dtypes == other.index.dtypes)
23562358
):
2357-
result = try_row_join(self, other, how=how)
2359+
result = try_legacy_row_join(self, other, how=how) or try_new_row_join(
2360+
self, other
2361+
)
23582362
if result is not None:
23592363
return result
23602364

@@ -2693,7 +2697,35 @@ def is_uniquely_named(self: BlockIndexProperties):
26932697
return len(set(self.names)) == len(self.names)
26942698

26952699

2696-
def try_row_join(
2700+
def try_new_row_join(
    left: Block, right: Block
) -> Optional[Tuple[Block, Tuple[Mapping[str, str], Mapping[str, str]],]]:
    """Try to align two blocks on their index columns via ``expr.try_row_join``.

    Index columns are paired positionally (left with right) and passed as the
    join keys.

    Returns:
        ``None`` if ``left.expr.try_row_join`` returns ``None``. Otherwise the
        combined block plus the ``(left, right)`` column-id mappings returned
        by that call. The result keeps the left index columns, labels and
        index names; the right block's index columns are dropped.
    """
    key_pairs = tuple(zip(left.index_columns, right.index_columns))
    outcome = left.expr.try_row_join(right.expr, key_pairs)
    if outcome is None:  # did not succeed
        return None

    joined_expr, mappings = outcome
    left_lookup, right_lookup = mappings

    # Keep the left index column, and drop the matching right column
    kept_index = [left_lookup[col_id] for col_id in left.index_columns]
    dropped = [right_lookup[col_id] for col_id in right.index_columns]
    joined_expr = joined_expr.drop_columns(dropped)

    joined_block = Block(
        joined_expr,
        index_columns=kept_index,
        column_labels=left.column_labels.append(right.column_labels),
        index_labels=left.index.names,
    )
    return joined_block, (left_lookup, right_lookup)
2726+
2727+
2728+
def try_legacy_row_join(
26972729
left: Block,
26982730
right: Block,
26992731
*,
@@ -2727,7 +2759,7 @@ def try_row_join(
27272759
)
27282760
for id in right.value_columns
27292761
]
2730-
combined_expr = left_expr.try_align_as_projection(
2762+
combined_expr = left_expr.try_legacy_row_join(
27312763
right_expr,
27322764
join_type=how,
27332765
join_keys=join_keys,

bigframes/core/identifiers.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import itertools
1919
from typing import Generator
2020

21+
import bigframes.core.guid
22+
2123

2224
def standard_id_strings(prefix: str = "col_") -> Generator[str, None, None]:
2325
i = 0
@@ -47,6 +49,10 @@ def local_normalized(self) -> ColumnId:
4749
def __lt__(self, other: ColumnId) -> bool:
4850
return self.sql < other.sql
4951

52+
@classmethod
def unique(cls) -> ColumnId:
    """Return a new ``ColumnId`` whose name is a freshly generated guid."""
    # Always builds a plain ColumnId, regardless of which class this is
    # called on.
    fresh_name = bigframes.core.guid.generate_guid()
    return ColumnId(name=fresh_name)
55+
5056

5157
@dataclasses.dataclass(frozen=True)
5258
class SerialColumnId(ColumnId):

0 commit comments

Comments
 (0)