Skip to content

Commit f4e5b26

Browse files
authored
feat: enhance read_csv index_col parameter support (#1631)
This PR expands the functionality of the `index_col` parameter in the `read_csv` method. New capabilities include: 1. **Multi-column Indexing:** `index_col` now accepts an iterable of strings (column names) to create a MultiIndex. (Fixes internal issue 338089659) 2. **Integer Indexing:** Support for a single integer index or an iterable of integers (column positions) is also explicitly included/verified. (Fixes internal issue 404530013) 3. **Pandas Compatibility:** Adds tests to ensure that the behavior of `index_col` when set to `False`, `None`, or `True` aligns with standard Pandas behavior. (Fixes internal issue 338400133)
1 parent 233347a commit f4e5b26

File tree

5 files changed

+100
-58
lines changed

5 files changed

+100
-58
lines changed

bigframes/session/__init__.py

Lines changed: 3 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -961,37 +961,13 @@ def _read_csv_w_bigquery_engine(
961961
f"{constants.FEEDBACK_LINK}"
962962
)
963963

964-
# TODO(b/338089659): Looks like we can relax this 1 column
965-
# restriction if we check the contents of an iterable are strings
966-
# not integers.
967-
if (
968-
# Empty tuples, None, and False are allowed and falsey.
969-
index_col
970-
and not isinstance(index_col, bigframes.enums.DefaultIndexKind)
971-
and not isinstance(index_col, str)
972-
):
973-
raise NotImplementedError(
974-
"BigQuery engine only supports a single column name for `index_col`, "
975-
f"got: {repr(index_col)}. {constants.FEEDBACK_LINK}"
976-
)
964+
if index_col is True:
965+
raise ValueError("The value of index_col couldn't be 'True'")
977966

978967
# None and False cannot be passed to read_gbq.
979-
# TODO(b/338400133): When index_col is None, we should be using the
980-
# first column of the CSV as the index to be compatible with the
981-
# pandas engine. According to the pandas docs, only "False"
982-
# indicates a default sequential index.
983-
if not index_col:
968+
if index_col is None or index_col is False:
984969
index_col = ()
985970

986-
index_col = typing.cast(
987-
Union[
988-
Sequence[str], # Falsey values
989-
bigframes.enums.DefaultIndexKind,
990-
str,
991-
],
992-
index_col,
993-
)
994-
995971
# usecols should only be an iterable of strings (column names) for use as columns in read_gbq.
996972
columns: Tuple[Any, ...] = tuple()
997973
if usecols is not None:

bigframes/session/_io/bigquery/read_gbq_table.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,11 @@ def _is_table_clustered_or_partitioned(
230230

231231
def get_index_cols(
232232
table: bigquery.table.Table,
233-
index_col: Iterable[str] | str | bigframes.enums.DefaultIndexKind,
233+
index_col: Iterable[str]
234+
| str
235+
| Iterable[int]
236+
| int
237+
| bigframes.enums.DefaultIndexKind,
234238
) -> List[str]:
235239
"""
236240
If we can get a total ordering from the table, such as via primary key
@@ -240,6 +244,8 @@ def get_index_cols(
240244

241245
# Transform index_col -> index_cols so we have a variable that is
242246
# always a list of column names (possibly empty).
247+
schema_len = len(table.schema)
248+
index_cols: List[str] = []
243249
if isinstance(index_col, bigframes.enums.DefaultIndexKind):
244250
if index_col == bigframes.enums.DefaultIndexKind.SEQUENTIAL_INT64:
245251
# User has explicity asked for a default, sequential index.
@@ -255,9 +261,35 @@ def get_index_cols(
255261
f"Got unexpected index_col {repr(index_col)}. {constants.FEEDBACK_LINK}"
256262
)
257263
elif isinstance(index_col, str):
258-
index_cols: List[str] = [index_col]
264+
index_cols = [index_col]
265+
elif isinstance(index_col, int):
266+
if not 0 <= index_col < schema_len:
267+
raise ValueError(
268+
f"Integer index {index_col} is out of bounds "
269+
f"for table with {schema_len} columns (must be >= 0 and < {schema_len})."
270+
)
271+
index_cols = [table.schema[index_col].name]
272+
elif isinstance(index_col, Iterable):
273+
for item in index_col:
274+
if isinstance(item, str):
275+
index_cols.append(item)
276+
elif isinstance(item, int):
277+
if not 0 <= item < schema_len:
278+
raise ValueError(
279+
f"Integer index {item} is out of bounds "
280+
f"for table with {schema_len} columns (must be >= 0 and < {schema_len})."
281+
)
282+
index_cols.append(table.schema[item].name)
283+
else:
284+
raise TypeError(
285+
"If index_col is an iterable, it must contain either strings "
286+
"(column names) or integers (column positions)."
287+
)
259288
else:
260-
index_cols = list(index_col)
289+
raise TypeError(
290+
f"Unsupported type for index_col: {type(index_col).__name__}. Expected"
291+
"an integer, an string, an iterable of strings, or an iterable of integers."
292+
)
261293

262294
# If the isn't an index selected, use the primary keys of the table as the
263295
# index. If there are no primary keys, we'll return an empty list.

bigframes/session/loader.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,11 @@ def read_gbq_table(
289289
self,
290290
query: str,
291291
*,
292-
index_col: Iterable[str] | str | bigframes.enums.DefaultIndexKind = (),
292+
index_col: Iterable[str]
293+
| str
294+
| Iterable[int]
295+
| int
296+
| bigframes.enums.DefaultIndexKind = (),
293297
columns: Iterable[str] = (),
294298
max_results: Optional[int] = None,
295299
api_name: str = "read_gbq_table",
@@ -516,7 +520,11 @@ def read_bigquery_load_job(
516520
filepath_or_buffer: str | IO["bytes"],
517521
*,
518522
job_config: bigquery.LoadJobConfig,
519-
index_col: Iterable[str] | str | bigframes.enums.DefaultIndexKind = (),
523+
index_col: Iterable[str]
524+
| str
525+
| Iterable[int]
526+
| int
527+
| bigframes.enums.DefaultIndexKind = (),
520528
columns: Iterable[str] = (),
521529
) -> dataframe.DataFrame:
522530
# Need to create session table beforehand

tests/system/small/test_session.py

Lines changed: 52 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1216,55 +1216,86 @@ def test_read_csv_for_local_file_w_sep(session, df_and_local_csv, sep):
12161216
pd.testing.assert_frame_equal(bf_df.to_pandas(), pd_df.to_pandas())
12171217

12181218

1219-
def test_read_csv_w_index_col_false(session, df_and_local_csv):
1219+
@pytest.mark.parametrize(
1220+
"index_col",
1221+
[
1222+
pytest.param(None, id="none"),
1223+
pytest.param(False, id="false"),
1224+
pytest.param([], id="empty_list"),
1225+
],
1226+
)
1227+
def test_read_csv_for_index_col_w_false(session, df_and_local_csv, index_col):
12201228
# Compares results for pandas and bigframes engines
12211229
scalars_df, path = df_and_local_csv
12221230
with open(path, "rb") as buffer:
12231231
bf_df = session.read_csv(
12241232
buffer,
12251233
engine="bigquery",
1226-
index_col=False,
1234+
index_col=index_col,
12271235
)
12281236
with open(path, "rb") as buffer:
12291237
# Convert default pandas dtypes to match BigQuery DataFrames dtypes.
12301238
pd_df = session.read_csv(
1231-
buffer, index_col=False, dtype=scalars_df.dtypes.to_dict()
1239+
buffer, index_col=index_col, dtype=scalars_df.dtypes.to_dict()
12321240
)
12331241

1234-
assert bf_df.shape[0] == scalars_df.shape[0]
1235-
assert bf_df.shape[0] == pd_df.shape[0]
1236-
1237-
# We use a default index because of index_col=False, so the previous index
1238-
# column is just loaded as a column.
1239-
assert len(bf_df.columns) == len(scalars_df.columns) + 1
1240-
assert len(bf_df.columns) == len(pd_df.columns)
1242+
assert bf_df.shape == pd_df.shape
12411243

12421244
# BigFrames requires `sort_index()` because BigQuery doesn't preserve row IDs
12431245
# (b/280889935) or guarantee row ordering.
12441246
bf_df = bf_df.set_index("rowindex").sort_index()
12451247
pd_df = pd_df.set_index("rowindex")
1246-
1247-
pd.testing.assert_frame_equal(bf_df.to_pandas(), scalars_df.to_pandas())
12481248
pd.testing.assert_frame_equal(bf_df.to_pandas(), pd_df.to_pandas())
12491249

12501250

1251-
def test_read_csv_w_index_col_column_label(session, df_and_gcs_csv):
1252-
scalars_df, path = df_and_gcs_csv
1253-
bf_df = session.read_csv(path, engine="bigquery", index_col="rowindex")
1251+
@pytest.mark.parametrize(
1252+
"index_col",
1253+
[
1254+
pytest.param("rowindex", id="single_str"),
1255+
pytest.param(["rowindex", "bool_col"], id="multi_str"),
1256+
pytest.param(0, id="single_int"),
1257+
pytest.param([0, 2], id="multi_int"),
1258+
pytest.param([0, "bool_col"], id="mix_types"),
1259+
],
1260+
)
1261+
def test_read_csv_for_index_col(session, df_and_gcs_csv, index_col):
1262+
scalars_pandas_df, path = df_and_gcs_csv
1263+
bf_df = session.read_csv(path, engine="bigquery", index_col=index_col)
12541264

12551265
# Convert default pandas dtypes to match BigQuery DataFrames dtypes.
12561266
pd_df = session.read_csv(
1257-
path, index_col="rowindex", dtype=scalars_df.dtypes.to_dict()
1267+
path, index_col=index_col, dtype=scalars_pandas_df.dtypes.to_dict()
12581268
)
12591269

1260-
assert bf_df.shape == scalars_df.shape
12611270
assert bf_df.shape == pd_df.shape
1271+
pd.testing.assert_frame_equal(bf_df.to_pandas(), pd_df.to_pandas())
12621272

1263-
assert len(bf_df.columns) == len(scalars_df.columns)
1264-
assert len(bf_df.columns) == len(pd_df.columns)
12651273

1266-
pd.testing.assert_frame_equal(bf_df.to_pandas(), scalars_df.to_pandas())
1267-
pd.testing.assert_frame_equal(bf_df.to_pandas(), pd_df.to_pandas())
1274+
@pytest.mark.parametrize(
1275+
("index_col", "error_type", "error_msg"),
1276+
[
1277+
pytest.param(
1278+
True, ValueError, "The value of index_col couldn't be 'True'", id="true"
1279+
),
1280+
pytest.param(100, ValueError, "out of bounds", id="single_int"),
1281+
pytest.param([0, 200], ValueError, "out of bounds", id="multi_int"),
1282+
pytest.param(
1283+
[0.1], TypeError, "it must contain either strings", id="invalid_iterable"
1284+
),
1285+
pytest.param(
1286+
3.14, TypeError, "Unsupported type for index_col", id="unsupported_type"
1287+
),
1288+
],
1289+
)
1290+
def test_read_csv_raises_error_for_invalid_index_col(
1291+
session, df_and_gcs_csv, index_col, error_type, error_msg
1292+
):
1293+
_, path = df_and_gcs_csv
1294+
with pytest.raises(
1295+
error_type,
1296+
match=error_msg,
1297+
):
1298+
session.read_csv(path, engine="bigquery", index_col=index_col)
12681299

12691300

12701301
@pytest.mark.parametrize(

tests/unit/session/test_session.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,6 @@
118118
"BigQuery engine does not support these arguments",
119119
id="with_dtype",
120120
),
121-
pytest.param(
122-
{"engine": "bigquery", "index_col": 5},
123-
"BigQuery engine only supports a single column name for `index_col`.",
124-
id="with_index_col_not_str",
125-
),
126121
pytest.param(
127122
{"engine": "bigquery", "usecols": [1, 2]},
128123
"BigQuery engine only supports an iterable of strings for `usecols`.",

0 commit comments

Comments
 (0)