
[pandas.tools.plotting._subplots] squeeze option is not respected if ax is provided #16253

Closed
@Bolayniuss

Description


pandas.tools.plotting._subplots returns a flattened array of axes whenever the ax parameter is provided, regardless of the value of the squeeze parameter. If ax is a 2-D list, it is not flattened, but then the check len(ax) == naxes fails, because len() only counts the rows. The check should be ax.size == naxes.
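A minimal NumPy sketch of the mismatch described above, assuming (as the report states) that a 2-D list of axes reaches the count check as a 2-D array rather than being flattened. `len()` returns only the length of the first dimension, while `.size` counts every element. The helper `check_axes_count` is hypothetical (not pandas API), and placeholder `object()` instances stand in for matplotlib Axes.

```python
import numpy as np


def check_axes_count(ax, naxes):
    """Hypothetical version of the count check using .size instead of len().

    np.asarray(..., dtype=object) keeps a nested list's shape, so a 2x2
    list becomes a (2, 2) array; .size then counts all 4 elements.
    """
    arr = np.asarray(ax, dtype=object)
    if arr.size != naxes:
        raise ValueError(
            "The number of passed axes must be {0}, the same as the "
            "output plot".format(naxes)
        )
    return arr


# Placeholder objects standing in for the four Axes of a 2x2 grid.
grid = [[object(), object()], [object(), object()]]
arr = np.asarray(grid, dtype=object)

print(len(arr))   # 2: first dimension only, so len(ax) == 4 fails
print(arr.size)   # 4: all elements, so ax.size == 4 passes
check_axes_count(grid, naxes=4)  # no error
```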

This prevents passing the ax parameter to plotting functions such as scatter_matrix, which expect ax to be an NxN array.
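One way squeeze could be honoured for user-supplied axes, sketched under assumptions: `normalize_passed_axes` is a hypothetical helper, not pandas API, and it follows matplotlib's `plt.subplots` convention that `squeeze=False` always yields a 2-D array. A caller such as scatter_matrix could then index the result as `axes[i, j]`.

```python
import numpy as np


def normalize_passed_axes(ax, naxes, squeeze=True):
    """Hypothetical helper: validate user-supplied axes, then honour squeeze.

    squeeze=True  -> flat 1-D array (the current behaviour for passed axes)
    squeeze=False -> at least 2-D, so an N x N grid keeps its shape
    """
    arr = np.asarray(ax, dtype=object)
    if arr.size != naxes:
        raise ValueError("expected {0} axes, got {1}".format(naxes, arr.size))
    if squeeze:
        return arr.ravel()
    # A 2-D grid is returned unchanged; a flat sequence becomes one row.
    return np.atleast_2d(arr)


# Stand-in for the axes array returned by plt.subplots(2, 2).
grid = np.empty((2, 2), dtype=object)
for idx in np.ndindex(grid.shape):
    grid[idx] = object()

flat = normalize_passed_axes(grid, naxes=4, squeeze=True)
square = normalize_passed_axes(grid, naxes=4, squeeze=False)
print(flat.shape)    # (4,)
print(square.shape)  # (2, 2)
print(square[1, 0] is grid[1, 0])  # True: same objects, 2-D indexing works
```

Returning at least a 2-D array for `squeeze=False` keeps the shape of a grid the caller built, which is what an NxN consumer like scatter_matrix needs.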

import pandas as pd
from pandas.tools.plotting import scatter_matrix
import matplotlib.pyplot as plt

df = pd.DataFrame(dict(a=[0, 1, 2, 3, 4, 5], b=[5, 6, 7, 8, 9, 10]))

f, axes = plt.subplots(2, 2)
scatter_matrix(df, ax=axes)
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-10-8860714df14f> in <module>()
----> 1 scatter_matrix(df, ax=axes)

/Users/bolay/.envs/scantrust_tools/lib/python2.7/site-packages/pandas/tools/plotting.pyc in scatter_matrix(frame, alpha, figsize, ax, grid, diagonal, marker, density_kwds, hist_kwds, range_padding, **kwds)
    371     for i, a in zip(lrange(n), df.columns):
    372         for j, b in zip(lrange(n), df.columns):
--> 373             ax = axes[i, j]
    374 
    375             if i == j:

IndexError: too many indices for array
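The IndexError follows from the flattening: scatter_matrix looks up `axes[i, j]`, but the array it gets back for passed axes is 1-D. A small NumPy reproduction with placeholder objects instead of real Axes; the reshape to `(n, n)` only illustrates the layout the loop expects and is not a pandas-provided fix.

```python
import numpy as np

n = 2  # number of DataFrame columns, so n * n = 4 axes
flat_axes = np.empty(n * n, dtype=object)  # 1-D, like the flattened axes
for k in range(n * n):
    flat_axes[k] = object()

try:
    flat_axes[0, 1]  # two indices on a 1-D array
except IndexError as exc:
    print("IndexError:", exc)

# Restoring the N x N layout makes the axes[i, j] lookup valid again.
square_axes = flat_axes.reshape(n, n)
print(square_axes[0, 1] is flat_axes[1])  # True (row-major order)
```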

And if axes is passed as a 2-D list:

scatter_matrix(df, ax=axes.tolist())
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-11-aeac24929f47> in <module>()
----> 1 scatter_matrix(df, ax=axes.tolist())

/Users/bolay/.envs/scantrust_tools/lib/python2.7/site-packages/pandas/tools/plotting.pyc in scatter_matrix(frame, alpha, figsize, ax, grid, diagonal, marker, density_kwds, hist_kwds, range_padding, **kwds)
    347     naxes = n * n
    348     fig, axes = _subplots(naxes=naxes, figsize=figsize, ax=ax,
--> 349                           squeeze=False)
    350 
    351     # no gaps between subplots

/Users/bolay/.envs/scantrust_tools/lib/python2.7/site-packages/pandas/tools/plotting.pyc in _subplots(naxes, sharex, sharey, squeeze, subplot_kw, ax, layout, layout_type, **fig_kw)
   3389             else:
   3390                 raise ValueError("The number of passed axes must be {0}, the "
-> 3391                                  "same as the output plot".format(naxes))
   3392 
   3393         fig = ax.get_figure()

ValueError: The number of passed axes must be 4, the same as the output plot
