Skip to content

ENH: improved dtype inference for Index.map #44609

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/source/whatsnew/v1.4.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,8 @@ Other enhancements
``USFederalHolidayCalendar``. See also `Other API changes`_.
- :meth:`.Rolling.var`, :meth:`.Expanding.var`, :meth:`.Rolling.std`, :meth:`.Expanding.std` now support `Numba <http://numba.pydata.org/>`_ execution with the ``engine`` keyword (:issue:`44461`)
- :meth:`Series.info` has been added, for compatibility with :meth:`DataFrame.info` (:issue:`5167`)
- :meth:`UInt64Index.map` now retains ``dtype`` where possible (:issue:`44609`)
-


.. ---------------------------------------------------------------------------
Expand Down
10 changes: 10 additions & 0 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
can_hold_element,
find_common_type,
infer_dtype_from,
maybe_cast_pointwise_result,
validate_numeric_casting,
)
from pandas.core.dtypes.common import (
Expand Down Expand Up @@ -5977,6 +5978,15 @@ def map(self, mapper, na_action=None):
# empty
dtype = self.dtype

# e.g. if we are floating and new_values is all ints, then we
# don't want to cast back to floating. But if we are UInt64
# and new_values is all ints, we want to try.
same_dtype = lib.infer_dtype(new_values, skipna=False) == self.inferred_type
if same_dtype:
new_values = maybe_cast_pointwise_result(
new_values, self.dtype, same_dtype=same_dtype
)

if self._is_backward_compat_public_numeric_index and is_numeric_dtype(
new_values.dtype
):
Expand Down
30 changes: 8 additions & 22 deletions pandas/tests/indexes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@

from pandas.core.dtypes.common import (
is_datetime64tz_dtype,
is_float_dtype,
is_integer_dtype,
is_unsigned_integer_dtype,
)
from pandas.core.dtypes.dtypes import CategoricalDtype

Expand Down Expand Up @@ -557,20 +555,9 @@ def test_map(self, simple_index):
# callable
idx = simple_index

# we don't infer UInt64
if is_integer_dtype(idx.dtype):
expected = idx.astype("int64")
elif is_float_dtype(idx.dtype):
expected = idx.astype("float64")
if idx._is_backward_compat_public_numeric_index:
# We get a NumericIndex back, not Float64Index
expected = type(idx)(expected)
else:
expected = idx

result = idx.map(lambda x: x)
# For RangeIndex we convert to Int64Index
tm.assert_index_equal(result, expected, exact="equiv")
tm.assert_index_equal(result, idx, exact="equiv")

@pytest.mark.parametrize(
"mapper",
Expand All @@ -583,27 +570,26 @@ def test_map_dictlike(self, mapper, simple_index):

idx = simple_index
if isinstance(idx, CategoricalIndex):
# TODO(2.0): see if we can avoid skipping once
# CategoricalIndex.reindex is removed.
pytest.skip(f"skipping tests for {type(idx)}")

identity = mapper(idx.values, idx)

# we don't infer to UInt64 for a dict
if is_unsigned_integer_dtype(idx.dtype) and isinstance(identity, dict):
expected = idx.astype("int64")
else:
expected = idx

result = idx.map(identity)
# For RangeIndex we convert to Int64Index
tm.assert_index_equal(result, expected, exact="equiv")
tm.assert_index_equal(result, idx, exact="equiv")

# empty mappable
dtype = None
if idx._is_backward_compat_public_numeric_index:
new_index_cls = NumericIndex
if idx.dtype.kind == "f":
dtype = idx.dtype
else:
new_index_cls = Float64Index

expected = new_index_cls([np.nan] * len(idx))
expected = new_index_cls([np.nan] * len(idx), dtype=dtype)
result = idx.map(mapper(expected, idx))
tm.assert_index_equal(result, expected)

Expand Down
8 changes: 1 addition & 7 deletions pandas/tests/indexes/multi/test_analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,14 +174,8 @@ def test_map(idx):
# callable
index = idx

# we don't infer UInt64
if isinstance(index, UInt64Index):
expected = index.astype("int64")
else:
expected = index

result = index.map(lambda x: x)
tm.assert_index_equal(result, expected)
tm.assert_index_equal(result, index)


@pytest.mark.parametrize(
Expand Down
17 changes: 17 additions & 0 deletions pandas/tests/indexes/numeric/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,3 +670,20 @@ def test_float64_index_equals():

result = string_index.equals(float_index)
assert result is False


def test_map_dtype_inference_unsigned_to_signed():
    """Check that negating a UInt64Index via ``map`` gives an Int64Index."""
    # GH#44609 cases where we don't retain dtype
    unsigned = UInt64Index([1, 2, 3])
    negated = unsigned.map(lambda value: -value)
    tm.assert_index_equal(negated, Int64Index([-1, -2, -3]))


def test_map_dtype_inference_overflows():
    """Check that ``map`` on an int8 index upcasts to int64 when values grow."""
    # GH#44609 case where we have to upcast
    int8_values = np.array([1, 2, 3], dtype=np.int8)
    narrow = NumericIndex(int8_values)
    scaled = narrow.map(lambda value: value * 1000)
    # TODO: we could plausibly try to infer down to int16 here
    wanted = NumericIndex([1000, 2000, 3000], dtype=np.int64)
    tm.assert_index_equal(scaled, wanted)
14 changes: 1 addition & 13 deletions pandas/tests/indexes/test_any_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,8 @@
"""
import re

import numpy as np
import pytest

from pandas.core.dtypes.common import is_float_dtype

import pandas._testing as tm


Expand Down Expand Up @@ -49,16 +46,7 @@ def test_mutability(index):
def test_map_identity_mapping(index):
# GH#12766
result = index.map(lambda x: x)
if index._is_backward_compat_public_numeric_index:
if is_float_dtype(index.dtype):
expected = index.astype(np.float64)
elif index.dtype == np.uint64:
expected = index.astype(np.uint64)
else:
expected = index.astype(np.int64)
else:
expected = index
tm.assert_index_equal(result, expected, exact="equiv")
tm.assert_index_equal(result, index, exact="equiv")


def test_wrong_number_names(index):
Expand Down
15 changes: 9 additions & 6 deletions pandas/tests/indexes/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
period_range,
)
import pandas._testing as tm
from pandas.api.types import is_float_dtype
from pandas.core.api import (
Float64Index,
Int64Index,
Expand Down Expand Up @@ -535,11 +534,15 @@ def test_map_dictlike(self, index, mapper):
# to match proper result coercion for uints
expected = Index([])
elif index._is_backward_compat_public_numeric_index:
if is_float_dtype(index.dtype):
exp_dtype = np.float64
else:
exp_dtype = np.int64
expected = index._constructor(np.arange(len(index), 0, -1), dtype=exp_dtype)
expected = index._constructor(
np.arange(len(index), 0, -1), dtype=index.dtype
)
elif type(index) is Index and index.dtype != object:
# i.e. EA-backed, for now just Nullable
expected = Index(np.arange(len(index), 0, -1), dtype=index.dtype)
elif index.dtype.kind == "u":
# TODO: case where e.g. we cannot hold result in UInt8?
expected = Index(np.arange(len(index), 0, -1), dtype=index.dtype)
else:
expected = Index(np.arange(len(index), 0, -1))

Expand Down