@@ -3433,9 +3433,9 @@ def map(self, func, na_action: Optional[str] = None) -> DataFrame:
3433
3433
raise ValueError (f"na_action={ na_action } not supported" )
3434
3434
3435
3435
# TODO(shobs): Support **kwargs
3436
- # Reproject as workaround to applying filter too late. This forces the filter
3437
- # to be applied before passing data to remote function, protecting from bad
3438
- # inputs causing errors.
3436
+ # Reproject as workaround to applying filter too late. This forces the
3437
+ # filter to be applied before passing data to remote function,
3438
+ # protecting from bad inputs causing errors.
3439
3439
reprojected_df = DataFrame (self ._block ._force_reproject ())
3440
3440
return reprojected_df ._apply_unary_op (
3441
3441
ops .RemoteFunctionOp (func = func , apply_on_null = (na_action is None ))
@@ -3448,65 +3448,99 @@ def apply(self, func, *, axis=0, args: typing.Tuple = (), **kwargs):
3448
3448
category = bigframes .exceptions .PreviewWarning ,
3449
3449
)
3450
3450
3451
- # Early check whether the dataframe dtypes are currently supported
3452
- # in the remote function
3453
- # NOTE: Keep in sync with the value converters used in the gcf code
3454
- # generated in remote_function_template.py
3455
- remote_function_supported_dtypes = (
3456
- bigframes .dtypes .INT_DTYPE ,
3457
- bigframes .dtypes .FLOAT_DTYPE ,
3458
- bigframes .dtypes .BOOL_DTYPE ,
3459
- bigframes .dtypes .BYTES_DTYPE ,
3460
- bigframes .dtypes .STRING_DTYPE ,
3461
- )
3462
- supported_dtypes_types = tuple (
3463
- type (dtype )
3464
- for dtype in remote_function_supported_dtypes
3465
- if not isinstance (dtype , pandas .ArrowDtype )
3466
- )
3467
- # Check ArrowDtype separately since multiple BigQuery types map to
3468
- # ArrowDtype, including BYTES and TIMESTAMP.
3469
- supported_arrow_types = tuple (
3470
- dtype .pyarrow_dtype
3471
- for dtype in remote_function_supported_dtypes
3472
- if isinstance (dtype , pandas .ArrowDtype )
3473
- )
3474
- supported_dtypes_hints = tuple (
3475
- str (dtype ) for dtype in remote_function_supported_dtypes
3476
- )
3477
-
3478
- for dtype in self .dtypes :
3479
- if (
3480
- # Not one of the pandas/numpy types.
3481
- not isinstance (dtype , supported_dtypes_types )
3482
- # And not one of the arrow types.
3483
- and not (
3484
- isinstance (dtype , pandas .ArrowDtype )
3485
- and any (
3486
- dtype .pyarrow_dtype .equals (arrow_type )
3487
- for arrow_type in supported_arrow_types
3488
- )
3489
- )
3490
- ):
3491
- raise NotImplementedError (
3492
- f"DataFrame has a column of dtype '{ dtype } ' which is not supported with axis=1."
3493
- f" Supported dtypes are { supported_dtypes_hints } ."
3494
- )
3495
-
3496
3451
# Check if the function is a remote function
3497
3452
if not hasattr (func , "bigframes_remote_function" ):
3498
3453
raise ValueError ("For axis=1 a remote function must be used." )
3499
3454
3500
- # Serialize the rows as json values
3501
- block = self ._get_block ()
3502
- rows_as_json_series = bigframes .series .Series (
3503
- block ._get_rows_as_json_values ()
3504
- )
3455
+ is_row_processor = getattr (func , "is_row_processor" )
3456
+ if is_row_processor :
3457
+ # Early check whether the dataframe dtypes are currently supported
3458
+ # in the remote function
3459
+ # NOTE: Keep in sync with the value converters used in the gcf code
3460
+ # generated in remote_function_template.py
3461
+ remote_function_supported_dtypes = (
3462
+ bigframes .dtypes .INT_DTYPE ,
3463
+ bigframes .dtypes .FLOAT_DTYPE ,
3464
+ bigframes .dtypes .BOOL_DTYPE ,
3465
+ bigframes .dtypes .BYTES_DTYPE ,
3466
+ bigframes .dtypes .STRING_DTYPE ,
3467
+ )
3468
+ supported_dtypes_types = tuple (
3469
+ type (dtype )
3470
+ for dtype in remote_function_supported_dtypes
3471
+ if not isinstance (dtype , pandas .ArrowDtype )
3472
+ )
3473
+ # Check ArrowDtype separately since multiple BigQuery types map to
3474
+ # ArrowDtype, including BYTES and TIMESTAMP.
3475
+ supported_arrow_types = tuple (
3476
+ dtype .pyarrow_dtype
3477
+ for dtype in remote_function_supported_dtypes
3478
+ if isinstance (dtype , pandas .ArrowDtype )
3479
+ )
3480
+ supported_dtypes_hints = tuple (
3481
+ str (dtype ) for dtype in remote_function_supported_dtypes
3482
+ )
3505
3483
3506
- # Apply the function
3507
- result_series = rows_as_json_series ._apply_unary_op (
3508
- ops .RemoteFunctionOp (func = func , apply_on_null = True )
3509
- )
3484
+ for dtype in self .dtypes :
3485
+ if (
3486
+ # Not one of the pandas/numpy types.
3487
+ not isinstance (dtype , supported_dtypes_types )
3488
+ # And not one of the arrow types.
3489
+ and not (
3490
+ isinstance (dtype , pandas .ArrowDtype )
3491
+ and any (
3492
+ dtype .pyarrow_dtype .equals (arrow_type )
3493
+ for arrow_type in supported_arrow_types
3494
+ )
3495
+ )
3496
+ ):
3497
+ raise NotImplementedError (
3498
+ f"DataFrame has a column of dtype '{ dtype } ' which is not supported with axis=1."
3499
+ f" Supported dtypes are { supported_dtypes_hints } ."
3500
+ )
3501
+
3502
+ # Serialize the rows as json values
3503
+ block = self ._get_block ()
3504
+ rows_as_json_series = bigframes .series .Series (
3505
+ block ._get_rows_as_json_values ()
3506
+ )
3507
+
3508
+ # Apply the function
3509
+ result_series = rows_as_json_series ._apply_unary_op (
3510
+ ops .RemoteFunctionOp (func = func , apply_on_null = True )
3511
+ )
3512
+ else :
3513
+ # This is a special case where we are providing a non-pandas-like
3514
+ # extension. If the remote function can take one or more params
3515
+ # then we assume that here the user intention is to use the
3516
+ # column values of the dataframe as arguments to the function.
3517
+ # For this to work the following conditions must be true:
3518
+ # 1. The number of input params in the function must be the same
3519
+ # as the number of columns in the dataframe
3520
+ # 2. The dtypes of the columns in the dataframe must be
3521
+ # compatible with the data types of the input params
3522
+ # 3. The order of the columns in the dataframe must correspond
3523
+ # to the order of the input params in the function
3524
+ udf_input_dtypes = getattr (func , "input_dtypes" )
3525
+ if len (udf_input_dtypes ) != len (self .columns ):
3526
+ raise ValueError (
3527
+ f"Remote function takes { len (udf_input_dtypes )} arguments but DataFrame has { len (self .columns )} columns."
3528
+ )
3529
+ if udf_input_dtypes != tuple (self .dtypes .to_list ()):
3530
+ raise ValueError (
3531
+ f"Remote function takes arguments of types { udf_input_dtypes } but DataFrame dtypes are { tuple (self .dtypes )} ."
3532
+ )
3533
+
3534
+ series_list = [self [col ] for col in self .columns ]
3535
+ # Reproject as workaround to applying filter too late. This forces the
3536
+ # filter to be applied before passing data to remote function,
3537
+ # protecting from bad inputs causing errors.
3538
+ reprojected_series = bigframes .series .Series (
3539
+ series_list [0 ]._block ._force_reproject ()
3540
+ )
3541
+ result_series = reprojected_series ._apply_nary_op (
3542
+ ops .NaryRemoteFunctionOp (func = func ), series_list [1 :]
3543
+ )
3510
3544
result_series .name = None
3511
3545
3512
3546
# Return Series with materialized result so that any error in the remote
0 commit comments