Skip to content

Commit 2158818

Browse files
authored
feat: df.apply(axis=1) to support remote function with multiple params (#851)
* feat: extend `df.apply(axis=1)` to support remote function with multiple params * add doctest, make small test remote function sticky * handle single param non-row-processing functions * reword the documentation a bit * handle missing input dtype in read_gbq_function * restore input types as tuple in read_gbq_function * clear previous remote function attributes * reword documentation for clarity * add/update comments to explain force reproject * make doctest example remote function with 3 params
1 parent 9959fc8 commit 2158818

File tree

10 files changed

+417
-106
lines changed

10 files changed

+417
-106
lines changed

bigframes/core/compile/scalar_op_compiler.py

+24-2
Original file line numberDiff line numberDiff line change
@@ -191,19 +191,27 @@ def normalized_impl(args: typing.Sequence[ibis_types.Value], op: ops.RowOp):
191191

192192
return decorator
193193

194-
def register_nary_op(self, op_ref: typing.Union[ops.NaryOp, type[ops.NaryOp]]):
194+
def register_nary_op(
195+
self, op_ref: typing.Union[ops.NaryOp, type[ops.NaryOp]], pass_op: bool = False
196+
):
195197
"""
196198
Decorator to register a nary op implementation.
197199
198200
Args:
199201
op_ref (NaryOp or NaryOp type):
200202
Class or instance of operator that is implemented by the decorated function.
203+
pass_op (bool):
204+
Set to true if implementation takes the operator object as the last argument.
205+
This is needed for parameterized ops where parameters are part of op object.
201206
"""
202207
key = typing.cast(str, op_ref.name)
203208

204209
def decorator(impl: typing.Callable[..., ibis_types.Value]):
205210
def normalized_impl(args: typing.Sequence[ibis_types.Value], op: ops.RowOp):
206-
return impl(*args)
211+
if pass_op:
212+
return impl(*args, op=op)
213+
else:
214+
return impl(*args)
207215

208216
self._register(key, normalized_impl)
209217
return impl
@@ -1468,6 +1476,7 @@ def clip_op(
14681476
)
14691477

14701478

1479+
# N-ary Operations
14711480
@scalar_op_compiler.register_nary_op(ops.case_when_op)
14721481
def case_when_op(*cases_and_outputs: ibis_types.Value) -> ibis_types.Value:
14731482
# ibis can handle most type coercions, but we need to force bool -> int
@@ -1487,6 +1496,19 @@ def case_when_op(*cases_and_outputs: ibis_types.Value) -> ibis_types.Value:
14871496
return case_val.end()
14881497

14891498

1499+
@scalar_op_compiler.register_nary_op(ops.NaryRemoteFunctionOp, pass_op=True)
1500+
def nary_remote_function_op_impl(
1501+
*operands: ibis_types.Value, op: ops.NaryRemoteFunctionOp
1502+
):
1503+
ibis_node = getattr(op.func, "ibis_node", None)
1504+
if ibis_node is None:
1505+
raise TypeError(
1506+
f"only a bigframes remote function is supported as a callable. {constants.FEEDBACK_LINK}"
1507+
)
1508+
result = ibis_node(*operands)
1509+
return result
1510+
1511+
14901512
# Helpers
14911513
def is_null(value) -> bool:
14921514
# float NaN/inf should be treated as distinct from 'true' null values

bigframes/dataframe.py

+91-57
Original file line numberDiff line numberDiff line change
@@ -3433,9 +3433,9 @@ def map(self, func, na_action: Optional[str] = None) -> DataFrame:
34333433
raise ValueError(f"na_action={na_action} not supported")
34343434

34353435
# TODO(shobs): Support **kwargs
3436-
# Reproject as workaround to applying filter too late. This forces the filter
3437-
# to be applied before passing data to remote function, protecting from bad
3438-
# inputs causing errors.
3436+
# Reproject as workaround to applying filter too late. This forces the
3437+
# filter to be applied before passing data to remote function,
3438+
# protecting from bad inputs causing errors.
34393439
reprojected_df = DataFrame(self._block._force_reproject())
34403440
return reprojected_df._apply_unary_op(
34413441
ops.RemoteFunctionOp(func=func, apply_on_null=(na_action is None))
@@ -3448,65 +3448,99 @@ def apply(self, func, *, axis=0, args: typing.Tuple = (), **kwargs):
34483448
category=bigframes.exceptions.PreviewWarning,
34493449
)
34503450

3451-
# Early check whether the dataframe dtypes are currently supported
3452-
# in the remote function
3453-
# NOTE: Keep in sync with the value converters used in the gcf code
3454-
# generated in remote_function_template.py
3455-
remote_function_supported_dtypes = (
3456-
bigframes.dtypes.INT_DTYPE,
3457-
bigframes.dtypes.FLOAT_DTYPE,
3458-
bigframes.dtypes.BOOL_DTYPE,
3459-
bigframes.dtypes.BYTES_DTYPE,
3460-
bigframes.dtypes.STRING_DTYPE,
3461-
)
3462-
supported_dtypes_types = tuple(
3463-
type(dtype)
3464-
for dtype in remote_function_supported_dtypes
3465-
if not isinstance(dtype, pandas.ArrowDtype)
3466-
)
3467-
# Check ArrowDtype separately since multiple BigQuery types map to
3468-
# ArrowDtype, including BYTES and TIMESTAMP.
3469-
supported_arrow_types = tuple(
3470-
dtype.pyarrow_dtype
3471-
for dtype in remote_function_supported_dtypes
3472-
if isinstance(dtype, pandas.ArrowDtype)
3473-
)
3474-
supported_dtypes_hints = tuple(
3475-
str(dtype) for dtype in remote_function_supported_dtypes
3476-
)
3477-
3478-
for dtype in self.dtypes:
3479-
if (
3480-
# Not one of the pandas/numpy types.
3481-
not isinstance(dtype, supported_dtypes_types)
3482-
# And not one of the arrow types.
3483-
and not (
3484-
isinstance(dtype, pandas.ArrowDtype)
3485-
and any(
3486-
dtype.pyarrow_dtype.equals(arrow_type)
3487-
for arrow_type in supported_arrow_types
3488-
)
3489-
)
3490-
):
3491-
raise NotImplementedError(
3492-
f"DataFrame has a column of dtype '{dtype}' which is not supported with axis=1."
3493-
f" Supported dtypes are {supported_dtypes_hints}."
3494-
)
3495-
34963451
# Check if the function is a remote function
34973452
if not hasattr(func, "bigframes_remote_function"):
34983453
raise ValueError("For axis=1 a remote function must be used.")
34993454

3500-
# Serialize the rows as json values
3501-
block = self._get_block()
3502-
rows_as_json_series = bigframes.series.Series(
3503-
block._get_rows_as_json_values()
3504-
)
3455+
is_row_processor = getattr(func, "is_row_processor")
3456+
if is_row_processor:
3457+
# Early check whether the dataframe dtypes are currently supported
3458+
# in the remote function
3459+
# NOTE: Keep in sync with the value converters used in the gcf code
3460+
# generated in remote_function_template.py
3461+
remote_function_supported_dtypes = (
3462+
bigframes.dtypes.INT_DTYPE,
3463+
bigframes.dtypes.FLOAT_DTYPE,
3464+
bigframes.dtypes.BOOL_DTYPE,
3465+
bigframes.dtypes.BYTES_DTYPE,
3466+
bigframes.dtypes.STRING_DTYPE,
3467+
)
3468+
supported_dtypes_types = tuple(
3469+
type(dtype)
3470+
for dtype in remote_function_supported_dtypes
3471+
if not isinstance(dtype, pandas.ArrowDtype)
3472+
)
3473+
# Check ArrowDtype separately since multiple BigQuery types map to
3474+
# ArrowDtype, including BYTES and TIMESTAMP.
3475+
supported_arrow_types = tuple(
3476+
dtype.pyarrow_dtype
3477+
for dtype in remote_function_supported_dtypes
3478+
if isinstance(dtype, pandas.ArrowDtype)
3479+
)
3480+
supported_dtypes_hints = tuple(
3481+
str(dtype) for dtype in remote_function_supported_dtypes
3482+
)
35053483

3506-
# Apply the function
3507-
result_series = rows_as_json_series._apply_unary_op(
3508-
ops.RemoteFunctionOp(func=func, apply_on_null=True)
3509-
)
3484+
for dtype in self.dtypes:
3485+
if (
3486+
# Not one of the pandas/numpy types.
3487+
not isinstance(dtype, supported_dtypes_types)
3488+
# And not one of the arrow types.
3489+
and not (
3490+
isinstance(dtype, pandas.ArrowDtype)
3491+
and any(
3492+
dtype.pyarrow_dtype.equals(arrow_type)
3493+
for arrow_type in supported_arrow_types
3494+
)
3495+
)
3496+
):
3497+
raise NotImplementedError(
3498+
f"DataFrame has a column of dtype '{dtype}' which is not supported with axis=1."
3499+
f" Supported dtypes are {supported_dtypes_hints}."
3500+
)
3501+
3502+
# Serialize the rows as json values
3503+
block = self._get_block()
3504+
rows_as_json_series = bigframes.series.Series(
3505+
block._get_rows_as_json_values()
3506+
)
3507+
3508+
# Apply the function
3509+
result_series = rows_as_json_series._apply_unary_op(
3510+
ops.RemoteFunctionOp(func=func, apply_on_null=True)
3511+
)
3512+
else:
3513+
# This is a special case where we are providing not-pandas-like
3514+
# extension. If the remote function can take one or more params
3515+
# then we assume that here the user intention is to use the
3516+
# column values of the dataframe as arguments to the function.
3517+
# For this to work the following condition must be true:
3518+
# 1. The number or input params in the function must be same
3519+
# as the number of columns in the dataframe
3520+
# 2. The dtypes of the columns in the dataframe must be
3521+
# compatible with the data types of the input params
3522+
# 3. The order of the columns in the dataframe must correspond
3523+
# to the order of the input params in the function
3524+
udf_input_dtypes = getattr(func, "input_dtypes")
3525+
if len(udf_input_dtypes) != len(self.columns):
3526+
raise ValueError(
3527+
f"Remote function takes {len(udf_input_dtypes)} arguments but DataFrame has {len(self.columns)} columns."
3528+
)
3529+
if udf_input_dtypes != tuple(self.dtypes.to_list()):
3530+
raise ValueError(
3531+
f"Remote function takes arguments of types {udf_input_dtypes} but DataFrame dtypes are {tuple(self.dtypes)}."
3532+
)
3533+
3534+
series_list = [self[col] for col in self.columns]
3535+
# Reproject as workaround to applying filter too late. This forces the
3536+
# filter to be applied before passing data to remote function,
3537+
# protecting from bad inputs causing errors.
3538+
reprojected_series = bigframes.series.Series(
3539+
series_list[0]._block._force_reproject()
3540+
)
3541+
result_series = reprojected_series._apply_nary_op(
3542+
ops.NaryRemoteFunctionOp(func=func), series_list[1:]
3543+
)
35103544
result_series.name = None
35113545

35123546
# Return Series with materialized result so that any error in the remote

bigframes/exceptions.py

+4
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,7 @@ class QueryComplexityError(RuntimeError):
5757

5858
class TimeTravelDisabledWarning(Warning):
5959
"""A query was reattempted without time travel."""
60+
61+
62+
class UnknownDataTypeWarning(Warning):
63+
"""Data type is unknown."""

bigframes/functions/remote_function.py

+37-3
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
from bigframes import clients
6767
import bigframes.constants as constants
6868
import bigframes.core.compile.ibis_types
69+
import bigframes.dtypes
6970
import bigframes.functions.remote_function_template
7071

7172
logger = logging.getLogger(__name__)
@@ -895,8 +896,8 @@ def remote_function(
895896
reuse (bool, Optional):
896897
Reuse the remote function if already exists.
897898
`True` by default, which will result in reusing an existing remote
898-
function and corresponding cloud function (if any) that was
899-
previously created for the same udf.
899+
function and corresponding cloud function that was previously
900+
created (if any) for the same udf.
900901
Please note that for an unnamed (i.e. created without an explicit
901902
`name` argument) remote function, the BigQuery DataFrames
902903
session id is attached in the cloud artifacts names. So for the
@@ -1174,7 +1175,9 @@ def try_delattr(attr):
11741175

11751176
try_delattr("bigframes_cloud_function")
11761177
try_delattr("bigframes_remote_function")
1178+
try_delattr("input_dtypes")
11771179
try_delattr("output_dtype")
1180+
try_delattr("is_row_processor")
11781181
try_delattr("ibis_node")
11791182

11801183
(
@@ -1216,12 +1219,20 @@ def try_delattr(attr):
12161219
rf_name
12171220
)
12181221
)
1219-
1222+
func.input_dtypes = tuple(
1223+
[
1224+
bigframes.core.compile.ibis_types.ibis_dtype_to_bigframes_dtype(
1225+
input_type
1226+
)
1227+
for input_type in ibis_signature.input_types
1228+
]
1229+
)
12201230
func.output_dtype = (
12211231
bigframes.core.compile.ibis_types.ibis_dtype_to_bigframes_dtype(
12221232
ibis_signature.output_type
12231233
)
12241234
)
1235+
func.is_row_processor = is_row_processor
12251236
func.ibis_node = node
12261237

12271238
# If a new remote function was created, update the cloud artifacts
@@ -1305,6 +1316,29 @@ def func(*ignored_args, **ignored_kwargs):
13051316
signature=(ibis_signature.input_types, ibis_signature.output_type),
13061317
)
13071318
func.bigframes_remote_function = str(routine_ref) # type: ignore
1319+
1320+
# set input bigframes data types
1321+
has_unknown_dtypes = False
1322+
function_input_dtypes = []
1323+
for ibis_type in ibis_signature.input_types:
1324+
input_dtype = cast(bigframes.dtypes.Dtype, bigframes.dtypes.DEFAULT_DTYPE)
1325+
if ibis_type is None:
1326+
has_unknown_dtypes = True
1327+
else:
1328+
input_dtype = (
1329+
bigframes.core.compile.ibis_types.ibis_dtype_to_bigframes_dtype(
1330+
ibis_type
1331+
)
1332+
)
1333+
function_input_dtypes.append(input_dtype)
1334+
if has_unknown_dtypes:
1335+
warnings.warn(
1336+
"The function has one or more missing input data types."
1337+
f" BigQuery DataFrames will assume default data type {bigframes.dtypes.DEFAULT_DTYPE} for them.",
1338+
category=bigframes.exceptions.UnknownDataTypeWarning,
1339+
)
1340+
func.input_dtypes = tuple(function_input_dtypes) # type: ignore
1341+
13081342
func.output_dtype = bigframes.core.compile.ibis_types.ibis_dtype_to_bigframes_dtype( # type: ignore
13091343
ibis_signature.output_type
13101344
)

bigframes/operations/__init__.py

+13
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,19 @@ def output_type(self, *input_types):
659659
raise AttributeError("output_dtype not defined")
660660

661661

662+
@dataclasses.dataclass(frozen=True)
663+
class NaryRemoteFunctionOp(NaryOp):
664+
name: typing.ClassVar[str] = "nary_remote_function"
665+
func: typing.Callable
666+
667+
def output_type(self, *input_types):
668+
# This property should be set to a valid Dtype by the @remote_function decorator or read_gbq_function method
669+
if hasattr(self.func, "output_dtype"):
670+
return self.func.output_dtype
671+
else:
672+
raise AttributeError("output_dtype not defined")
673+
674+
662675
add_op = AddOp()
663676
sub_op = SubOp()
664677
mul_op = create_binary_op(name="mul", type_signature=op_typing.BINARY_NUMERIC)

bigframes/series.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -1442,9 +1442,6 @@ def apply(
14421442
) -> Series:
14431443
# TODO(shobs, b/274645634): Support convert_dtype, args, **kwargs
14441444
# is actually a ternary op
1445-
# Reproject as workaround to applying filter too late. This forces the filter
1446-
# to be applied before passing data to remote function, protecting from bad
1447-
# inputs causing errors.
14481445

14491446
if by_row not in ["compat", False]:
14501447
raise ValueError("Param by_row must be one of 'compat' or False")
@@ -1474,7 +1471,10 @@ def apply(
14741471
ex.message += f"\n{_remote_function_recommendation_message}"
14751472
raise
14761473

1477-
# We are working with remote function at this point
1474+
# We are working with remote function at this point.
1475+
# Reproject as workaround to applying filter too late. This forces the
1476+
# filter to be applied before passing data to remote function,
1477+
# protecting from bad inputs causing errors.
14781478
reprojected_series = Series(self._block._force_reproject())
14791479
result_series = reprojected_series._apply_unary_op(
14801480
ops.RemoteFunctionOp(func=func, apply_on_null=True)
@@ -1507,6 +1507,9 @@ def combine(
15071507
ex.message += f"\n{_remote_function_recommendation_message}"
15081508
raise
15091509

1510+
# Reproject as workaround to applying filter too late. This forces the
1511+
# filter to be applied before passing data to remote function,
1512+
# protecting from bad inputs causing errors.
15101513
reprojected_series = Series(self._block._force_reproject())
15111514
result_series = reprojected_series._apply_binary_op(
15121515
other, ops.BinaryRemoteFunctionOp(func=func)

bigframes/session/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1661,8 +1661,8 @@ def remote_function(
16611661
reuse (bool, Optional):
16621662
Reuse the remote function if already exists.
16631663
`True` by default, which will result in reusing an existing remote
1664-
function and corresponding cloud function (if any) that was
1665-
previously created for the same udf.
1664+
function and corresponding cloud function that was previously
1665+
created (if any) for the same udf.
16661666
Please note that for an unnamed (i.e. created without an explicit
16671667
`name` argument) remote function, the BigQuery DataFrames
16681668
session id is attached in the cloud artifacts names. So for the

0 commit comments

Comments
 (0)