Skip to content

Commit 0d24f73

Browse files
fix: Fewer relation joins from df self-operations (#823)
1 parent 27f8631 commit 0d24f73

File tree

6 files changed

+118
-83
lines changed

6 files changed

+118
-83
lines changed

bigframes/core/__init__.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -460,9 +460,9 @@ def _cross_join_w_labels(
460460
conditions=(), mappings=(*labels_mappings, *table_mappings), type="cross"
461461
)
462462
if join_side == "left":
463-
joined_array = self.join(labels_array, join_def=join)
463+
joined_array = self.relational_join(labels_array, join_def=join)
464464
else:
465-
joined_array = labels_array.join(self, join_def=join)
465+
joined_array = labels_array.relational_join(self, join_def=join)
466466
return joined_array
467467

468468
def _create_unpivot_labels_array(
@@ -485,30 +485,27 @@ def _create_unpivot_labels_array(
485485

486486
return ArrayValue.from_pyarrow(pa.Table.from_pylist(rows), session=self.session)
487487

488-
def join(
488+
def relational_join(
489489
self,
490490
other: ArrayValue,
491491
join_def: join_def.JoinDefinition,
492-
allow_row_identity_join: bool = False,
493-
):
492+
) -> ArrayValue:
494493
join_node = nodes.JoinNode(
495494
left_child=self.node,
496495
right_child=other.node,
497496
join=join_def,
498-
allow_row_identity_join=allow_row_identity_join,
499497
)
500-
if allow_row_identity_join:
501-
return ArrayValue(bigframes.core.rewrite.maybe_rewrite_join(join_node))
502498
return ArrayValue(join_node)
503499

504500
def try_align_as_projection(
505501
self,
506502
other: ArrayValue,
507503
join_type: join_def.JoinType,
504+
join_keys: typing.Tuple[join_def.CoalescedColumnMapping, ...],
508505
mappings: typing.Tuple[join_def.JoinColumnMapping, ...],
509506
) -> typing.Optional[ArrayValue]:
510507
result = bigframes.core.rewrite.join_as_projection(
511-
self.node, other.node, mappings, join_type
508+
self.node, other.node, join_keys, mappings, join_type
512509
)
513510
if result is not None:
514511
return ArrayValue(result)

bigframes/core/blocks.py

Lines changed: 45 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2008,7 +2008,7 @@ def merge(
20082008
mappings=(*left_mappings, *right_mappings),
20092009
type=how,
20102010
)
2011-
joined_expr = self.expr.join(other.expr, join_def=join_def)
2011+
joined_expr = self.expr.relational_join(other.expr, join_def=join_def)
20122012
result_columns = []
20132013
matching_join_labels = []
20142014

@@ -2267,25 +2267,33 @@ def join(
22672267
raise NotImplementedError(
22682268
f"Only how='outer','left','right','inner' currently supported. {constants.FEEDBACK_LINK}"
22692269
)
2270-
# Special case for null index,
2270+
# Handle null index, which only supports row join
2271+
if (self.index.nlevels == other.index.nlevels == 0) and not block_identity_join:
2272+
if not block_identity_join:
2273+
result = try_row_join(self, other, how=how)
2274+
if result is not None:
2275+
return result
2276+
raise bigframes.exceptions.NullIndexError(
2277+
"Cannot implicitly align objects. Set an explicit index using set_index."
2278+
)
2279+
2280+
# Oddly, pandas row-wise join ignores right index names
22712281
if (
2272-
(self.index.nlevels == other.index.nlevels == 0)
2273-
and not sort
2274-
and not block_identity_join
2282+
not block_identity_join
2283+
and (self.index.nlevels == other.index.nlevels)
2284+
and (self.index.dtypes == other.index.dtypes)
22752285
):
2276-
return join_indexless(self, other, how=how)
2286+
result = try_row_join(self, other, how=how)
2287+
if result is not None:
2288+
return result
22772289

22782290
self._throw_if_null_index("join")
22792291
other._throw_if_null_index("join")
22802292
if self.index.nlevels == other.index.nlevels == 1:
2281-
return join_mono_indexed(
2282-
self, other, how=how, sort=sort, block_identity_join=block_identity_join
2283-
)
2284-
else:
2293+
return join_mono_indexed(self, other, how=how, sort=sort)
2294+
else: # Handles cases where one or both sides are multi-indexed
22852295
# Always sort multi-index join
2286-
return join_multi_indexed(
2287-
self, other, how=how, sort=sort, block_identity_join=block_identity_join
2288-
)
2296+
return join_multi_indexed(self, other, how=how, sort=sort)
22892297

22902298
def _force_reproject(self) -> Block:
22912299
"""Forces a reprojection of the underlying tables expression. Used to force predicate/order application before subsequent operations."""
@@ -2623,46 +2631,55 @@ def is_uniquely_named(self: BlockIndexProperties):
26232631
return len(set(self.names)) == len(self.names)
26242632

26252633

2626-
def join_indexless(
2634+
def try_row_join(
26272635
left: Block,
26282636
right: Block,
26292637
*,
26302638
how="left",
2631-
) -> Tuple[Block, Tuple[Mapping[str, str], Mapping[str, str]],]:
2632-
"""Joins two blocks"""
2639+
) -> Optional[Tuple[Block, Tuple[Mapping[str, str], Mapping[str, str]],]]:
2640+
"""Joins two blocks that have a common root expression by merging the projections."""
26332641
left_expr = left.expr
26342642
right_expr = right.expr
2643+
# Create a new array value, mapping from both, then left, and then right
2644+
join_keys = tuple(
2645+
join_defs.CoalescedColumnMapping(
2646+
left_source_id=left_id,
2647+
right_source_id=right_id,
2648+
destination_id=guid.generate_guid(),
2649+
)
2650+
for left_id, right_id in zip(left.index_columns, right.index_columns)
2651+
)
26352652
left_mappings = [
26362653
join_defs.JoinColumnMapping(
26372654
source_table=join_defs.JoinSide.LEFT,
26382655
source_id=id,
26392656
destination_id=guid.generate_guid(),
26402657
)
2641-
for id in left_expr.column_ids
2658+
for id in left.value_columns
26422659
]
26432660
right_mappings = [
26442661
join_defs.JoinColumnMapping(
26452662
source_table=join_defs.JoinSide.RIGHT,
26462663
source_id=id,
26472664
destination_id=guid.generate_guid(),
26482665
)
2649-
for id in right_expr.column_ids
2666+
for id in right.value_columns
26502667
]
26512668
combined_expr = left_expr.try_align_as_projection(
26522669
right_expr,
26532670
join_type=how,
2671+
join_keys=join_keys,
26542672
mappings=(*left_mappings, *right_mappings),
26552673
)
26562674
if combined_expr is None:
2657-
raise bigframes.exceptions.NullIndexError(
2658-
"Cannot implicitly align objects. Set an explicit index using set_index."
2659-
)
2675+
return None
26602676
get_column_left = {m.source_id: m.destination_id for m in left_mappings}
26612677
get_column_right = {m.source_id: m.destination_id for m in right_mappings}
26622678
block = Block(
26632679
combined_expr,
26642680
column_labels=[*left.column_labels, *right.column_labels],
2665-
index_columns=(),
2681+
index_columns=(key.destination_id for key in join_keys),
2682+
index_labels=left.index.names,
26662683
)
26672684
return (
26682685
block,
@@ -2704,7 +2721,7 @@ def join_with_single_row(
27042721
mappings=(*left_mappings, *right_mappings),
27052722
type="cross",
27062723
)
2707-
combined_expr = left_expr.join(
2724+
combined_expr = left_expr.relational_join(
27082725
right_expr,
27092726
join_def=join_def,
27102727
)
@@ -2731,7 +2748,6 @@ def join_mono_indexed(
27312748
*,
27322749
how="left",
27332750
sort=False,
2734-
block_identity_join: bool = False,
27352751
) -> Tuple[Block, Tuple[Mapping[str, str], Mapping[str, str]],]:
27362752
left_expr = left.expr
27372753
right_expr = right.expr
@@ -2759,14 +2775,14 @@ def join_mono_indexed(
27592775
mappings=(*left_mappings, *right_mappings),
27602776
type=how,
27612777
)
2762-
combined_expr = left_expr.join(
2778+
2779+
combined_expr = left_expr.relational_join(
27632780
right_expr,
27642781
join_def=join_def,
2765-
allow_row_identity_join=(not block_identity_join),
27662782
)
2783+
27672784
get_column_left = join_def.get_left_mapping()
27682785
get_column_right = join_def.get_right_mapping()
2769-
# Drop original indices from each side. and used the coalesced combination generated by the join.
27702786
left_index = get_column_left[left.index_columns[0]]
27712787
right_index = get_column_right[right.index_columns[0]]
27722788
# Drop the original indices from each side and use the coalesced combination generated by the join.
@@ -2800,7 +2816,6 @@ def join_multi_indexed(
28002816
*,
28012817
how="left",
28022818
sort=False,
2803-
block_identity_join: bool = False,
28042819
) -> Tuple[Block, Tuple[Mapping[str, str], Mapping[str, str]],]:
28052820
if not (left.index.is_uniquely_named() and right.index.is_uniquely_named()):
28062821
raise ValueError("Joins not supported on indices with non-unique level names")
@@ -2819,8 +2834,6 @@ def join_multi_indexed(
28192834
left_join_ids = [left.index.resolve_level_exact(name) for name in common_names]
28202835
right_join_ids = [right.index.resolve_level_exact(name) for name in common_names]
28212836

2822-
names_fully_match = len(left_only_names) == 0 and len(right_only_names) == 0
2823-
28242837
left_expr = left.expr
28252838
right_expr = right.expr
28262839

@@ -2850,13 +2863,11 @@ def join_multi_indexed(
28502863
type=how,
28512864
)
28522865

2853-
combined_expr = left_expr.join(
2866+
combined_expr = left_expr.relational_join(
28542867
right_expr,
28552868
join_def=join_def,
2856-
# If we're only joining on a subset of the index columns, we need to
2857-
# perform a true join.
2858-
allow_row_identity_join=(names_fully_match and not block_identity_join),
28592869
)
2870+
28602871
get_column_left = join_def.get_left_mapping()
28612872
get_column_right = join_def.get_right_mapping()
28622873
left_ids_post_join = [get_column_left[id] for id in left_join_ids]

bigframes/core/join_def.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,15 @@ class JoinColumnMapping:
4343
destination_id: str
4444

4545

46+
@dataclasses.dataclass(frozen=True)
47+
class CoalescedColumnMapping:
48+
"""Special column mapping used only by implicit joiner only"""
49+
50+
left_source_id: str
51+
right_source_id: str
52+
destination_id: str
53+
54+
4655
@dataclasses.dataclass(frozen=True)
4756
class JoinDefinition:
4857
conditions: Tuple[JoinCondition, ...]

bigframes/core/nodes.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,6 @@ class JoinNode(BigFrameNode):
183183
left_child: BigFrameNode
184184
right_child: BigFrameNode
185185
join: JoinDefinition
186-
allow_row_identity_join: bool = False
187186

188187
@property
189188
def row_preserving(self) -> bool:

bigframes/core/rewrite.py

Lines changed: 38 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -106,28 +106,33 @@ def order_with(self, by: Tuple[order.OrderingExpression, ...]):
106106
)
107107

108108
def can_merge(
109-
self, right: SquashedSelect, join_def: join_defs.JoinDefinition
109+
self,
110+
right: SquashedSelect,
111+
join_keys: Tuple[join_defs.CoalescedColumnMapping, ...],
110112
) -> bool:
111113
"""Determines whether the two selections can be merged into a single selection."""
112-
if join_def.type == "cross":
113-
# Cannot convert cross join to projection
114-
return False
115-
116114
r_exprs_by_id = {id: expr for expr, id in right.columns}
117115
l_exprs_by_id = {id: expr for expr, id in self.columns}
118-
l_join_exprs = [l_exprs_by_id[cond.left_id] for cond in join_def.conditions]
119-
r_join_exprs = [r_exprs_by_id[cond.right_id] for cond in join_def.conditions]
116+
l_join_exprs = [
117+
l_exprs_by_id[join_key.left_source_id] for join_key in join_keys
118+
]
119+
r_join_exprs = [
120+
r_exprs_by_id[join_key.right_source_id] for join_key in join_keys
121+
]
120122

121-
if (self.root != right.root) or any(
122-
l_expr != r_expr for l_expr, r_expr in zip(l_join_exprs, r_join_exprs)
123-
):
123+
if self.root != right.root:
124+
return False
125+
if len(l_join_exprs) != len(r_join_exprs):
126+
return False
127+
if any(l_expr != r_expr for l_expr, r_expr in zip(l_join_exprs, r_join_exprs)):
124128
return False
125129
return True
126130

127131
def merge(
128132
self,
129133
right: SquashedSelect,
130134
join_type: join_defs.JoinType,
135+
join_keys: Tuple[join_defs.CoalescedColumnMapping, ...],
131136
mappings: Tuple[join_defs.JoinColumnMapping, ...],
132137
) -> SquashedSelect:
133138
if self.root != right.root:
@@ -147,11 +152,9 @@ def merge(
147152
l_relative, r_relative = relative_predicates(self.predicate, right.predicate)
148153
lmask = l_relative if join_type in {"right", "outer"} else None
149154
rmask = r_relative if join_type in {"left", "outer"} else None
150-
if lmask is not None:
151-
lselection = tuple((apply_mask(expr, lmask), id) for expr, id in lselection)
152-
if rmask is not None:
153-
rselection = tuple((apply_mask(expr, rmask), id) for expr, id in rselection)
154-
new_columns = remap_names(mappings, lselection, rselection)
155+
new_columns = merge_expressions(
156+
join_keys, mappings, lselection, rselection, lmask, rmask
157+
)
155158

156159
# Reconstruct ordering
157160
reverse_root = self.reverse_root
@@ -204,34 +207,21 @@ def expand(self) -> nodes.BigFrameNode:
204207
return nodes.ProjectionNode(child=root, assignments=self.columns)
205208

206209

207-
def maybe_rewrite_join(join_node: nodes.JoinNode) -> nodes.BigFrameNode:
208-
rewrite_common_node = common_selection_root(
209-
join_node.left_child, join_node.right_child
210-
)
211-
if rewrite_common_node is None:
212-
return join_node
213-
left_side = SquashedSelect.from_node_span(join_node.left_child, rewrite_common_node)
214-
right_side = SquashedSelect.from_node_span(
215-
join_node.right_child, rewrite_common_node
216-
)
217-
if left_side.can_merge(right_side, join_node.join):
218-
return left_side.merge(
219-
right_side, join_node.join.type, join_node.join.mappings
220-
).expand()
221-
return join_node
222-
223-
224210
def join_as_projection(
225211
l_node: nodes.BigFrameNode,
226212
r_node: nodes.BigFrameNode,
213+
join_keys: Tuple[join_defs.CoalescedColumnMapping, ...],
227214
mappings: Tuple[join_defs.JoinColumnMapping, ...],
228215
how: join_defs.JoinType,
229216
) -> Optional[nodes.BigFrameNode]:
230217
rewrite_common_node = common_selection_root(l_node, r_node)
231218
if rewrite_common_node is not None:
232219
left_side = SquashedSelect.from_node_span(l_node, rewrite_common_node)
233220
right_side = SquashedSelect.from_node_span(r_node, rewrite_common_node)
234-
merged = left_side.merge(right_side, how, mappings)
221+
if not left_side.can_merge(right_side, join_keys):
222+
# Most likely because join keys didn't match
223+
return None
224+
merged = left_side.merge(right_side, how, join_keys, mappings)
235225
assert (
236226
merged is not None
237227
), "Couldn't merge nodes. This shouldn't happen. Please share full stacktrace with the BigQuery DataFrames team at bigframes-feedback@google.com."
@@ -240,21 +230,33 @@ def join_as_projection(
240230
return None
241231

242232

243-
def remap_names(
233+
def merge_expressions(
234+
join_keys: Tuple[join_defs.CoalescedColumnMapping, ...],
244235
mappings: Tuple[join_defs.JoinColumnMapping, ...],
245236
lselection: Selection,
246237
rselection: Selection,
238+
lmask: Optional[scalar_exprs.Expression],
239+
rmask: Optional[scalar_exprs.Expression],
247240
) -> Selection:
248241
new_selection: Selection = tuple()
249242
l_exprs_by_id = {id: expr for expr, id in lselection}
250243
r_exprs_by_id = {id: expr for expr, id in rselection}
244+
for key in join_keys:
245+
# Join key expressions are equivalent on both sides, so either the left or the right key can be used
246+
assert l_exprs_by_id[key.left_source_id] == r_exprs_by_id[key.right_source_id]
247+
expr = l_exprs_by_id[key.left_source_id]
248+
id = key.destination_id
249+
new_selection = (*new_selection, (expr, id))
251250
for mapping in mappings:
252251
if mapping.source_table == join_defs.JoinSide.LEFT:
253252
expr = l_exprs_by_id[mapping.source_id]
253+
if lmask is not None:
254+
expr = apply_mask(expr, lmask)
254255
else: # Right
255256
expr = r_exprs_by_id[mapping.source_id]
256-
id = mapping.destination_id
257-
new_selection = (*new_selection, (expr, id))
257+
if rmask is not None:
258+
expr = apply_mask(expr, rmask)
259+
new_selection = (*new_selection, (expr, mapping.destination_id))
258260
return new_selection
259261

260262

0 commit comments

Comments
 (0)