Skip to content

Commit 952cab9

Browse files
authored
feat: add ml.model_selection.KFold class (#1001)
1 parent 6b34244 commit 952cab9

File tree

4 files changed

+367
-28
lines changed

4 files changed

+367
-28
lines changed

bigframes/ml/model_selection.py

Lines changed: 49 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,12 @@
1717
https://scikit-learn.org/stable/modules/classes.html#module-sklearn.model_selection."""
1818

1919

20-
from typing import cast, List, Union
20+
import inspect
21+
from typing import cast, Generator, List, Union
2122

23+
import bigframes_vendored.sklearn.model_selection._split as vendored_model_selection_split
24+
25+
from bigframes.core import log_adapter
2226
from bigframes.ml import utils
2327
import bigframes.pandas as bpd
2428

@@ -30,30 +34,6 @@ def train_test_split(
3034
random_state: Union[int, None] = None,
3135
stratify: Union[bpd.Series, None] = None,
3236
) -> List[Union[bpd.DataFrame, bpd.Series]]:
33-
"""Splits dataframes or series into random train and test subsets.
34-
35-
Args:
36-
*arrays (bigframes.dataframe.DataFrame or bigframes.series.Series):
37-
A sequence of BigQuery DataFrames or Series that can be joined on
38-
their indexes.
39-
test_size (default None):
40-
The proportion of the dataset to include in the test split. If
41-
None, this will default to the complement of train_size. If both
42-
are none, it will be set to 0.25.
43-
train_size (default None):
44-
The proportion of the dataset to include in the train split. If
45-
None, this will default to the complement of test_size.
46-
random_state (default None):
47-
A seed to use for randomly choosing the rows of the split. If not
48-
set, a random split will be generated each time.
49-
stratify: (bigframes.series.Series or None, default None):
50-
If not None, data is split in a stratified fashion, using this as the class labels. Each split has the same distribution of the class labels with the original dataset.
51-
Default to None.
52-
Note: By setting the stratify parameter, the memory consumption and generated SQL will be linear to the unique values in the Series. May return errors if the unique values size is too large.
53-
54-
Returns:
55-
List[Union[bigframes.dataframe.DataFrame, bigframes.series.Series]]: A list of BigQuery DataFrames or Series.
56-
"""
5737

5838
# TODO(garrettwu): scikit-learn throws an error when the dataframes don't have the same
5939
# number of rows. We probably want to do something similar. Now the implementation is based
@@ -123,3 +103,47 @@ def _stratify_split(df: bpd.DataFrame, stratify: bpd.Series) -> List[bpd.DataFra
123103
results.append(joined_df_test[columns])
124104

125105
return results
106+
107+
108+
# The docstring is not written inline on train_test_split; it is copied from the
# function of the same name in the vendored scikit-learn split module.
# inspect.getdoc() returns that text with its indentation cleaned up.
train_test_split.__doc__ = inspect.getdoc(
    vendored_model_selection_split.train_test_split
)
111+
112+
113+
@log_adapter.class_logger
class KFold(vendored_model_selection_split.KFold):
    # NOTE: this class is documented with comments rather than docstrings so
    # that documentation inherited from the vendored base class is not
    # shadowed (presumably the user-facing docs live there -- confirm).

    def __init__(self, n_splits: int = 5, *, random_state: Union[int, None] = None):
        # Fewer than two folds is rejected up front.
        if n_splits < 2:
            raise ValueError(f"n_splits must be at least 2. Got {n_splits}")
        self._n_splits = n_splits
        self._random_state = random_state

    def get_n_splits(self) -> int:
        # Returns the n_splits value given to the constructor.
        return self._n_splits

    def split(
        self,
        X: Union[bpd.DataFrame, bpd.Series],
        y: Union[bpd.DataFrame, bpd.Series, None] = None,
    ) -> Generator[tuple[Union[bpd.DataFrame, bpd.Series, None], ...], None, None]:
        # Yields one (X_train, X_test, y_train, y_test) tuple per fold. The X
        # parts come back in the same kind of container as X (DataFrame or
        # Series); the y parts likewise follow y, and are None when y is None.
        X_df = next(utils.convert_to_dataframe(X))
        y_df_or = next(utils.convert_to_dataframe(y)) if y is not None else None

        # X and y are joined first so that a single partitioning of the rows
        # is applied to both; they are separated again by column below.
        joined_df = X_df.join(y_df_or, how="outer") if y_df_or is not None else X_df

        # One equal fraction per fold.
        fracs = (1 / self._n_splits,) * self._n_splits

        dfs = joined_df._split(fracs=fracs, random_state=self._random_state)

        # Each partition takes one turn as the test set while the remaining
        # partitions are concatenated into the training set.
        for i, test_df in enumerate(dfs):
            train_df = bpd.concat(dfs[:i] + dfs[i + 1 :])

            X_train = train_df[X_df.columns]
            y_train = train_df[y_df_or.columns] if y_df_or is not None else None

            X_test = test_df[X_df.columns]
            y_test = test_df[y_df_or.columns] if y_df_or is not None else None

            # Convert the pieces back to the container types of the inputs.
            yield utils.convert_to_types(
                [X_train, X_test, y_train, y_test], [X, X, y, y]
            )

bigframes/ml/utils.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import typing
16-
from typing import Any, Iterable, Literal, Mapping, Optional, Union
16+
from typing import Any, Generator, Iterable, Literal, Mapping, Optional, Union
1717

1818
import bigframes_vendored.constants as constants
1919
from google.cloud import bigquery
@@ -25,7 +25,7 @@
2525
ArrayType = Union[bpd.DataFrame, bpd.Series]
2626

2727

28-
def convert_to_dataframe(*input: ArrayType) -> Iterable[bpd.DataFrame]:
28+
def convert_to_dataframe(*input: ArrayType) -> Generator[bpd.DataFrame, None, None]:
    """Lazily convert every positional argument to a DataFrame.

    Returns a generator; each argument is passed through
    ``_convert_to_dataframe`` only when the generator is advanced.
    """
    converted = (_convert_to_dataframe(item) for item in input)
    return converted
3030

3131

@@ -39,7 +39,7 @@ def _convert_to_dataframe(frame: ArrayType) -> bpd.DataFrame:
3939
)
4040

4141

42-
def convert_to_series(*input: ArrayType) -> Iterable[bpd.Series]:
42+
def convert_to_series(*input: ArrayType) -> Generator[bpd.Series, None, None]:
    """Lazily convert every positional argument to a Series.

    Returns a generator; each argument is passed through
    ``_convert_to_series`` only when the generator is advanced.
    """
    converted = (_convert_to_series(item) for item in input)
    return converted
4444

4545

@@ -60,6 +60,39 @@ def _convert_to_series(frame: ArrayType) -> bpd.Series:
6060
)
6161

6262

63+
def convert_to_types(
    inputs: Iterable[Union[ArrayType, None]],
    type_instances: Iterable[Union[ArrayType, None]],
) -> tuple[Union[ArrayType, None]]:
    """Convert each input to the container type of its paired type instance.

    ``inputs`` and ``type_instances`` are paired positionally with ``zip``, so
    surplus items in the longer iterable are ignored. Each pair is handed to
    ``_convert_to_type``.

    Returns:
        A tuple holding one converted value (DataFrame, Series or None) per pair.
    """
    # NOTE(review): the return annotation denotes a 1-tuple although the result
    # has one element per pair. It is left as-is because callers' annotations
    # are written against it; consider ``tuple[..., ...]`` in a joint change.
    return tuple(
        _convert_to_type(value, template)
        for value, template in zip(inputs, type_instances)
    )
72+
73+
74+
def _convert_to_type(
    input: Union[ArrayType, None], type_instance: Union[ArrayType, None]
):
    """Convert ``input`` to the same container type as ``type_instance``.

    Returns:
        None when both arguments are None; otherwise ``input`` converted to a
        DataFrame or a Series, matching the type of ``type_instance``.

    Raises:
        ValueError: If exactly one of the two arguments is None, or if
            ``type_instance`` is neither a DataFrame nor a Series.
    """
    # None-ness must agree on both sides before any conversion is attempted.
    if type_instance is None and input is None:
        return None
    if type_instance is None:
        raise ValueError(
            f"Trying to convert not None type to None. {constants.FEEDBACK_LINK}"
        )
    if input is None:
        raise ValueError(
            f"Trying to convert None type to not None. {constants.FEEDBACK_LINK}"
        )

    # Checked in order: DataFrame first, then Series.
    converters = (
        (bpd.DataFrame, _convert_to_dataframe),
        (bpd.Series, _convert_to_series),
    )
    for target_cls, converter in converters:
        if isinstance(type_instance, target_cls):
            return converter(input)

    raise ValueError(
        f"Unsupport converting to {type(type_instance)}. {constants.FEEDBACK_LINK}"
    )
94+
95+
6396
def parse_model_endpoint(model_endpoint: str) -> tuple[str, Optional[str]]:
6497
"""Parse model endpoint string to model_name and version."""
6598
model_name = model_endpoint

tests/system/small/ml/test_model_selection.py

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import math
16+
1517
import pandas as pd
1618
import pytest
1719

@@ -302,3 +304,174 @@ def test_train_test_split_stratify(df_fixture, request):
302304
test_counts,
303305
check_index_type=False,
304306
)
307+
308+
309+
@pytest.mark.parametrize("n_splits", (3, 5, 10))
def test_KFold_get_n_splits(n_splits):
    """get_n_splits() reports the fold count passed positionally to KFold."""
    splitter = model_selection.KFold(n_splits)
    reported = splitter.get_n_splits()
    assert reported == n_splits
316+
317+
318+
@pytest.mark.parametrize(
    "df_fixture",
    ("penguins_df_default_index", "penguins_df_null_index"),
)
@pytest.mark.parametrize("n_splits", (3, 5))
def test_KFold_split(df_fixture, n_splits, request):
    """Every fold yields DataFrame X parts and Series y parts of the expected sizes."""
    df = request.getfixturevalue(df_fixture)

    splitter = model_selection.KFold(n_splits=n_splits)

    feature_columns = ["species", "island", "culmen_length_mm"]
    X = df[feature_columns]
    y = df["body_mass_g"]

    total_rows = len(df)
    small_test = math.floor(total_rows / n_splits)
    large_test = math.ceil(total_rows / n_splits)
    # Depending on the iteration, the train/test sizes can differ by one row:
    # a smaller test fold goes with a larger train set and vice versa.
    accepted_sizes = [
        (total_rows - small_test, small_test),
        (total_rows - large_test, large_test),
    ]

    for X_train, X_test, y_train, y_test in splitter.split(X, y):  # type: ignore
        assert isinstance(X_train, bpd.DataFrame)
        assert isinstance(X_test, bpd.DataFrame)
        assert isinstance(y_train, bpd.Series)
        assert isinstance(y_test, bpd.Series)

        assert any(
            X_train.shape == (n_train, 3)
            and y_train.shape == (n_train,)
            and X_test.shape == (n_test, 3)
            and y_test.shape == (n_test,)
            for n_train, n_test in accepted_sizes
        )
366+
367+
368+
@pytest.mark.parametrize(
    "df_fixture",
    ("penguins_df_default_index", "penguins_df_null_index"),
)
@pytest.mark.parametrize("n_splits", (3, 5))
def test_KFold_split_X_only(df_fixture, n_splits, request):
    """With y=None, every fold yields DataFrame X parts and None for both y parts."""
    df = request.getfixturevalue(df_fixture)

    splitter = model_selection.KFold(n_splits=n_splits)

    feature_columns = ["species", "island", "culmen_length_mm"]
    X = df[feature_columns]

    total_rows = len(df)
    small_test = math.floor(total_rows / n_splits)
    large_test = math.ceil(total_rows / n_splits)
    # Depending on the iteration, the train/test sizes can differ by one row:
    # a smaller test fold goes with a larger train set and vice versa.
    accepted_sizes = [
        (total_rows - small_test, small_test),
        (total_rows - large_test, large_test),
    ]

    for X_train, X_test, y_train, y_test in splitter.split(X, y=None):  # type: ignore
        assert isinstance(X_train, bpd.DataFrame)
        assert isinstance(X_test, bpd.DataFrame)
        assert y_train is None
        assert y_test is None

        assert any(
            X_train.shape == (n_train, 3) and X_test.shape == (n_test, 3)
            for n_train, n_test in accepted_sizes
        )
411+
412+
413+
def test_KFold_seeded_correct_rows(session, penguins_pandas_df_default_index):
    """With random_state=42, the first fold contains exactly the expected rows."""
    splitter = model_selection.KFold(random_state=42)

    # `penguins_pandas_df_default_index` is used because this test depends on a
    # stable row order being present end to end. Only the heaviest penguins are
    # kept so that the expected indexes below stay a reasonable size.
    source = penguins_pandas_df_default_index
    all_data = source[source.body_mass_g > 5500]

    # bigframes loses the index if it doesn't have a name.
    all_data.index.name = "rowindex"

    df = session.read_pandas(all_data)

    X = df[["species", "island", "culmen_length_mm"]]
    y = df["body_mass_g"]
    first_fold = next(splitter.split(X, y))  # type: ignore
    X_train, X_test, y_train, y_test = first_fold

    # Download and sort every part before any comparison is made.
    actual_indexes = [
        part.to_pandas().sort_index().index
        for part in (X_train, X_test, y_train, y_test)
    ]

    train_index: pd.Index = pd.Index(
        [
            144, 146, 148, 161, 168, 183, 217, 221, 225, 226, 237,
            244, 257, 262, 264, 266, 267, 269, 278, 289, 290, 291,
        ],  # fmt: skip
        dtype="Int64",
        name="rowindex",
    )
    test_index: pd.Index = pd.Index(
        [186, 240, 245, 260, 263, 268], dtype="Int64", name="rowindex"
    )

    # Same order as the parts above: X_train, X_test, y_train, y_test.
    expected_indexes = (train_index, test_index, train_index, test_index)
    for actual, expected in zip(actual_indexes, expected_indexes):
        pd.testing.assert_index_equal(actual, expected)

0 commit comments

Comments
 (0)