Merged
10 changes: 9 additions & 1 deletion doc/source/user_guide/groupby.rst
@@ -761,7 +761,7 @@ different dtypes, then a common dtype will be determined in the same way as ``Da
Transformation
--------------

-The ``transform`` method returns an object that is indexed the same (same size)
+The ``transform`` method returns an object that is indexed the same
as the one being grouped. The transform function must:

* Return a result that is either the same size as the group chunk or
@@ -776,6 +776,14 @@ as the one being grouped. The transform function must:
* (Optionally) operates on the entire group chunk. If this is supported, a
  fast path is used starting from the *second* chunk.

.. deprecated:: 1.5.0

    When using ``.transform`` on a grouped DataFrame and the transformation function
    returns a DataFrame, currently pandas does not align the result's index
    with the input's index. This behavior is deprecated and alignment will
    be performed in a future version of pandas. You can apply ``.to_numpy()`` to the
    result of the transformation function to avoid alignment.

Similar to :ref:`groupby.aggregate.udfs`, the resulting dtype will reflect that of the
transformation function. If the results from different groups have different dtypes, then
a common dtype will be determined in the same way as ``DataFrame`` construction.
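
As a minimal sketch of the two return shapes the list above allows (same size as the group chunk, or broadcastable to it), the following uses a ``SeriesGroupBy`` on made-up data. The frame, the column names ``key`` and ``val``, and the variable names are assumptions chosen for illustration; they do not come from this PR.

```python
import pandas as pd

# Hypothetical data, used only for illustration.
df = pd.DataFrame({"key": ["x", "x", "y", "z", "z"], "val": [5, 4, 3, 2, 1]})
gb = df.groupby("key")["val"]

# Same-size result: one output value per input row of the group.
demeaned = gb.transform(lambda s: s - s.mean())

# Scalar result: broadcast to every row of the group.
group_max = gb.transform(lambda s: s.max())

print(demeaned.tolist())   # [0.5, -0.5, 0.0, 0.5, -0.5]
print(group_max.tolist())  # [5, 5, 3, 2, 2]
```

In both cases the result carries the same index as ``df``, which is what lets it be assigned straight back as a new column.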
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.5.0.rst
@@ -690,6 +690,7 @@ Other Deprecations
- Deprecated the ``closed`` argument in :class:`intervaltree` in favor of ``inclusive`` argument; In a future version passing ``closed`` will raise (:issue:`40245`)
- Deprecated the ``closed`` argument in :class:`ArrowInterval` in favor of ``inclusive`` argument; In a future version passing ``closed`` will raise (:issue:`40245`)
- Deprecated allowing ``unit="M"`` or ``unit="Y"`` in :class:`Timestamp` constructor with a non-round float value (:issue:`47267`)
- Deprecated :meth:`DataFrameGroupBy.transform` not aligning the result when the UDF returned a DataFrame (:issue:`45648`)
- Deprecated the ``display.column_space`` global configuration option (:issue:`7576`)
-

19 changes: 19 additions & 0 deletions pandas/core/groupby/generic.py
@@ -1196,14 +1196,33 @@ def _transform_general(self, func, *args, **kwargs):
applied.append(res)

        # Compute and process with the remaining groups
        emit_alignment_warning = False
        for name, group in gen:
            if group.size == 0:
                continue
            object.__setattr__(group, "name", name)
            res = path(group)
            if (
                not emit_alignment_warning
                and res.ndim == 2
                and not res.index.equals(group.index)
            ):
                emit_alignment_warning = True

            res = _wrap_transform_general_frame(self.obj, group, res)
            applied.append(res)

        if emit_alignment_warning:
            # GH#45648
            warnings.warn(
                "In a future version of pandas, returning a DataFrame in "
                "groupby.transform will align with the input's index. Apply "
                "`.to_numpy()` to the result in the transform function to keep "
                "the current behavior and silence this warning.",
                FutureWarning,
                stacklevel=find_stack_level(),
            )

        concat_index = obj.columns if self.axis == 0 else obj.index
        other_axis = 1 if self.axis == 0 else 0  # switches between 0 & 1
        concatenated = concat(applied, axis=self.axis, verify_integrity=False)
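
The hunk above sets a flag inside the loop and warns once after it, so a call with many misaligned groups still emits a single ``FutureWarning``. A standalone sketch of that pattern, independent of pandas, might look like the following. ``process_groups`` and its plain-list "indexes" are hypothetical stand-ins, not part of this PR.

```python
import warnings


def process_groups(groups):
    """Hypothetical stand-in for the loop in ``_transform_general``.

    ``groups`` is an iterable of (group_index, result_index) pairs,
    with plain lists playing the role of pandas indexes.
    """
    emit_alignment_warning = False
    processed = []
    for group_index, result_index in groups:
        # Once the flag is set, short-circuiting skips further comparisons.
        if not emit_alignment_warning and result_index != group_index:
            emit_alignment_warning = True
        processed.append(result_index)

    if emit_alignment_warning:
        # Emitted at most once per call, however many groups were misaligned.
        warnings.warn(
            "result index differs from group index",
            FutureWarning,
            stacklevel=2,
        )
    return processed


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Two of the three groups are misaligned.
    process_groups([([0, 1], [1, 0]), ([2], [2]), ([3, 4], [4, 3])])

print(len(caught))  # 1
```

Deferring the warning until after the loop keeps the per-group work cheap and avoids flooding the user with one message per group.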
12 changes: 10 additions & 2 deletions pandas/core/groupby/groupby.py
@@ -373,14 +373,14 @@ class providing the base-class of operations.
"""

_transform_template = """
-Call function producing a like-indexed %(klass)s on each group and
+Call function producing a same-indexed %(klass)s on each group and
return a %(klass)s having the same indexes as the original object
filled with the transformed values.

Parameters
----------
f : function
-    Function to apply to each group.
+    Function to apply to each group. See the Notes section below for requirements.

Can also accept a Numba JIT function with
``engine='numba'`` specified.
@@ -451,6 +451,14 @@ class providing the base-class of operations.
The resulting dtype will reflect the return value of the passed ``func``,
see the examples below.

.. deprecated:: 1.5.0

    When using ``.transform`` on a grouped DataFrame and the transformation function
    returns a DataFrame, currently pandas does not align the result's index
    with the input's index. This behavior is deprecated and alignment will
    be performed in a future version of pandas. You can apply ``.to_numpy()`` to the
    result of the transformation function to avoid alignment.

Examples
--------

39 changes: 39 additions & 0 deletions pandas/tests/groupby/transform/test_transform.py
@@ -1531,3 +1531,42 @@ def test_null_group_str_transformer_series(request, dropna, transformation_func)
    result = gb.transform(transformation_func, *args)

    tm.assert_equal(result, expected)


@pytest.mark.parametrize(
    "func, series, expected_values",
    [
        (Series.sort_values, False, [4, 5, 3, 1, 2]),
        (lambda x: x.head(1), False, ValueError),
        # SeriesGroupBy already has correct behavior
        (Series.sort_values, True, [5, 4, 3, 2, 1]),
        (lambda x: x.head(1), True, [5.0, np.nan, 3.0, 2.0, np.nan]),
    ],
)
@pytest.mark.parametrize("keys", [["a1"], ["a1", "a2"]])
@pytest.mark.parametrize("keys_in_index", [True, False])
def test_transform_aligns_depr(func, series, expected_values, keys, keys_in_index):
    # GH#45648 - transform should align with the input's index
    df = DataFrame({"a1": [1, 1, 3, 2, 2], "b": [5, 4, 3, 2, 1]})
    if "a2" in keys:
        df["a2"] = df["a1"]
    if keys_in_index:
        df = df.set_index(keys, append=True)

    gb = df.groupby(keys)
    if series:
        gb = gb["b"]

    warn = None if series else FutureWarning
    msg = "returning a DataFrame in groupby.transform will align"
    if expected_values is ValueError:
        with tm.assert_produces_warning(warn, match=msg):
            with pytest.raises(ValueError, match="Length mismatch"):
                gb.transform(func)
    else:
        with tm.assert_produces_warning(warn, match=msg):
            result = gb.transform(func)
        expected = DataFrame({"b": expected_values}, index=df.index)
        if series:
            expected = expected["b"]
        tm.assert_equal(result, expected)
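
For readers who want to see the aligned behavior outside the test harness, here is a small sketch of the two ``SeriesGroupBy`` cases from the parametrization above, using the same data as the test. The variable names are illustrative, and the lambdas stand in for the functions the test passes.

```python
import pandas as pd

# Same frame as in test_transform_aligns_depr.
df = pd.DataFrame({"a1": [1, 1, 3, 2, 2], "b": [5, 4, 3, 2, 1]})
gb = df.groupby("a1")["b"]

# sort_values reorders rows inside each group, but the result is aligned
# back to the input's index, so the original row order is recovered.
sorted_result = gb.transform(lambda x: x.sort_values())

# head(1) keeps only the first row of each group; after alignment the
# dropped rows appear as NaN.
head_result = gb.transform(lambda x: x.head(1))

print(sorted_result.tolist())  # [5, 4, 3, 2, 1]
print(head_result.tolist())    # [5.0, nan, 3.0, 2.0, nan]
```

This is the behavior the deprecation moves ``DataFrameGroupBy.transform`` toward when the function returns a DataFrame.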