Skip to content

Commit 71d0bd4

Browse files
authored
feat: column specs for tabular transformation (#466)
- adds column_specs as an alternative to column_transformation - adds training_jobs.AutoMLTabularTrainingJob.get_auto_column_specs - adds aiplatform.column
1 parent e8121ad commit 71d0bd4

File tree

3 files changed

+333
-8
lines changed

3 files changed

+333
-8
lines changed

google/cloud/aiplatform/datasets/time_series_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def create(
4646
encryption_spec_key_name: Optional[str] = None,
4747
sync: bool = True,
4848
) -> "TimeSeriesDataset":
49-
"""Creates a new tabular dataset.
49+
"""Creates a new time series dataset.
5050
5151
Args:
5252
display_name (str):

google/cloud/aiplatform/training_jobs.py

Lines changed: 91 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import datetime
1919
import time
2020
from typing import Dict, List, Optional, Sequence, Tuple, Union
21+
import warnings
2122

2223
import abc
2324

@@ -2525,6 +2526,7 @@ def __init__(
25252526
display_name: str,
25262527
optimization_prediction_type: str,
25272528
optimization_objective: Optional[str] = None,
2529+
column_specs: Optional[Dict[str, str]] = None,
25282530
column_transformations: Optional[Union[Dict, List[Dict]]] = None,
25292531
optimization_objective_recall_value: Optional[float] = None,
25302532
optimization_objective_precision_value: Optional[float] = None,
@@ -2536,6 +2538,15 @@ def __init__(
25362538
):
25372539
"""Constructs a AutoML Tabular Training Job.
25382540
2541+
Example usage:
2542+
2543+
job = training_jobs.AutoMLTabularTrainingJob(
2544+
display_name="my_display_name",
2545+
optimization_prediction_type="classification",
2546+
optimization_objective="minimize-log-loss",
2547+
column_specs={"column_1": "auto", "column_2": "numeric"},
2548+
)
2549+
25392550
Args:
25402551
display_name (str):
25412552
Required. The user-defined name of this TrainingPipeline.
@@ -2576,15 +2587,29 @@ def __init__(
25762587
"minimize-rmse" (default) - Minimize root-mean-squared error (RMSE).
25772588
"minimize-mae" - Minimize mean-absolute error (MAE).
25782589
"minimize-rmsle" - Minimize root-mean-squared log error (RMSLE).
2579-
column_transformations (Optional[Union[Dict, List[Dict]]]):
2590+
column_specs (Dict[str, str]):
2591+
Optional. Alternative to column_transformations where the keys of the dict
2592+
are column names and their respective values are one of
2593+
AutoMLTabularTrainingJob.column_data_types.
2594+
When creating transformation for BigQuery Struct column, the column
2595+
should be flattened using "." as the delimiter. Only columns with no child
2596+
should have a transformation.
2597+
If an input column has no transformations on it, such a column is
2598+
ignored by the training, except for the targetColumn, which should have
2599+
no transformations defined on.
2600+
Only one of column_transformations or column_specs should be passed.
2601+
column_transformations (Union[Dict, List[Dict]]):
25802602
Optional. Transformations to apply to the input columns (i.e. columns other
25812603
than the targetColumn). Each transformation may produce multiple
25822604
result values from the column's value, and all are used for training.
25832605
When creating transformation for BigQuery Struct column, the column
2584-
should be flattened using "." as the delimiter.
2606+
should be flattened using "." as the delimiter. Only columns with no child
2607+
should have a transformation.
25852608
If an input column has no transformations on it, such a column is
25862609
ignored by the training, except for the targetColumn, which should have
25872610
no transformations defined on.
2611+
Only one of column_transformations or column_specs should be passed.
2612+
Consider using column_specs as column_transformations will be deprecated eventually.
25882613
optimization_objective_recall_value (float):
25892614
Optional. Required when maximize-precision-at-recall optimizationObjective was
25902615
picked, represents the recall value at which the optimization is done.
@@ -2628,6 +2653,9 @@ def __init__(
26282653
If set, the trained Model will be secured by this key.
26292654
26302655
Overrides encryption_spec_key_name set in aiplatform.init.
2656+
2657+
Raises:
2658+
ValueError: When both column_transforations and column_specs were passed
26312659
"""
26322660
super().__init__(
26332661
display_name=display_name,
@@ -2637,7 +2665,26 @@ def __init__(
26372665
training_encryption_spec_key_name=training_encryption_spec_key_name,
26382666
model_encryption_spec_key_name=model_encryption_spec_key_name,
26392667
)
2640-
self._column_transformations = column_transformations
2668+
# user populated transformations
2669+
if column_transformations is not None and column_specs is not None:
2670+
raise ValueError(
2671+
"Both column_transformations and column_specs were passed. Only one is allowed."
2672+
)
2673+
if column_transformations is not None:
2674+
self._column_transformations = column_transformations
2675+
warnings.simplefilter("always", DeprecationWarning)
2676+
warnings.warn(
2677+
"consider using column_specs instead. column_transformations will be deprecated in the future.",
2678+
DeprecationWarning,
2679+
stacklevel=2,
2680+
)
2681+
elif column_specs is not None:
2682+
self._column_transformations = [
2683+
{transformation: {"column_name": column_name}}
2684+
for column_name, transformation in column_specs.items()
2685+
]
2686+
else:
2687+
self._column_transformations = None
26412688
self._optimization_objective = optimization_objective
26422689
self._optimization_prediction_type = optimization_prediction_type
26432690
self._optimization_objective_recall_value = optimization_objective_recall_value
@@ -2860,6 +2907,7 @@ def _run(
28602907

28612908
training_task_definition = schema.training_job.definition.automl_tabular
28622909

2910+
# auto-populate transformations
28632911
if self._column_transformations is None:
28642912
_LOGGER.info(
28652913
"No column transformations provided, so now retrieving columns from dataset in order to set default column transformations."
@@ -2870,21 +2918,19 @@ def _run(
28702918
for column_name in dataset.column_names
28712919
if column_name != target_column
28722920
]
2873-
column_transformations = [
2921+
self._column_transformations = [
28742922
{"auto": {"column_name": column_name}} for column_name in column_names
28752923
]
28762924

28772925
_LOGGER.info(
28782926
"The column transformation of type 'auto' was set for the following columns: %s."
28792927
% column_names
28802928
)
2881-
else:
2882-
column_transformations = self._column_transformations
28832929

28842930
training_task_inputs_dict = {
28852931
# required inputs
28862932
"targetColumn": target_column,
2887-
"transformations": column_transformations,
2933+
"transformations": self._column_transformations,
28882934
"trainBudgetMilliNodeHours": budget_milli_node_hours,
28892935
# optional inputs
28902936
"weightColumnName": weight_column,
@@ -2935,6 +2981,44 @@ def _add_additional_experiments(self, additional_experiments: List[str]):
29352981
"""
29362982
self._additional_experiments.extend(additional_experiments)
29372983

2984+
@staticmethod
2985+
def get_auto_column_specs(
2986+
dataset: datasets.TabularDataset, target_column: str,
2987+
) -> Dict[str, str]:
2988+
"""Returns a dict with all non-target columns as keys and 'auto' as values.
2989+
2990+
Example usage:
2991+
2992+
column_specs = training_jobs.AutoMLTabularTrainingJob.get_auto_column_specs(
2993+
dataset=my_dataset,
2994+
target_column="my_target_column",
2995+
)
2996+
2997+
Args:
2998+
dataset (datasets.TabularDataset):
2999+
Required. Intended dataset.
3000+
target_column(str):
3001+
Required. Intended target column.
3002+
Returns:
3003+
Dict[str, str]
3004+
Column names as keys and 'auto' as values
3005+
"""
3006+
column_names = [
3007+
column for column in dataset.column_names if column != target_column
3008+
]
3009+
column_specs = {column: "auto" for column in column_names}
3010+
return column_specs
3011+
3012+
class column_data_types:
3013+
AUTO = "auto"
3014+
NUMERIC = "numeric"
3015+
CATEGORICAL = "categorical"
3016+
TIMESTAMP = "timestamp"
3017+
TEXT = "text"
3018+
REPEATED_NUMERIC = "repeated_numeric"
3019+
REPEATED_CATEGORICAL = "repeated_categorical"
3020+
REPEATED_TEXT = "repeated_text"
3021+
29383022

29393023
class AutoMLForecastingTrainingJob(_TrainingJob):
29403024
_supported_training_schemas = (schema.training_job.definition.automl_forecasting,)

0 commit comments

Comments
 (0)