Skip to content

Commit 69fc7fd

Browse files
ivanmkc (Ivan Cheung)
authored
feat: Added tabular forecasting samples (#128)
* Added predict, get_model_evaluation and create_training_pipeline samples for AutoML Forecasting * Added param handlers * Added headers manually * fix: Improved forecasting sample * Added forecasting test * Added tests for predict and get_model_evaluation * fix: Fixed create_training_pipeline_sample * feat: Added list_model_evaluations_tabular_forecasting_sample and test, fixed get_model_evaluation_tabular_forecasting_sample, and fixed create_training_pipeline_tabular_forecasting_sample * fix: Reverted back to generated BUILD_SPECIFIC_GCLOUD_PROJECT * fix: Fixed name of test * fix: Fixed lint errors * fix: Fixed assertion * fix: Removed predict samples * Consolidated samples * fix: Removed list_model_evaluations_tabular_forecasting * fix: tweaks Co-authored-by: Ivan Cheung <[email protected]>
1 parent 624a08d commit 69fc7fd

File tree

5 files changed

+264
-0
lines changed

5 files changed

+264
-0
lines changed
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
16+
17+
def make_parent(parent: str) -> str:
    """Return the parent resource path for the request.

    Sample-config helper: the sample generator maps this onto the request's
    ``parent`` field; the value passes through unchanged.
    """
    # The original assigned `parent = parent`, a no-op; return directly.
    return parent
22+
23+
def make_training_pipeline(
    display_name: str,
    dataset_id: str,
    model_display_name: str,
    target_column: str,
    time_series_identifier_column: str,
    time_column: str,
    static_columns: str,
    time_variant_past_only_columns: str,
    time_variant_past_and_future_columns: str,
    forecast_window_end: int,
) -> google.cloud.aiplatform_v1alpha1.types.training_pipeline.TrainingPipeline:
    """Assemble the TrainingPipeline request body for an AutoML forecasting job.

    The returned dict follows the automl_forecasting_1.0.0.yaml training task
    schema and references an existing tabular dataset by ``dataset_id``.
    """
    # Forecast at a one-day granularity.
    data_granularity = {"unit": "day", "quantity": 1}

    # Columns used for training; "auto" lets the service infer each data type.
    column_transformations = [
        {"auto": {"column_name": column}}
        for column in (
            "date",
            "state_name",
            "county_fips_code",
            "confirmed_cases",
            "deaths",
        )
    ]

    # The inputs must be formatted according to the training_task_definition
    # yaml file referenced below.
    task_inputs = {
        # required inputs
        "targetColumn": target_column,
        "timeSeriesIdentifierColumn": time_series_identifier_column,
        "timeColumn": time_column,
        "transformations": column_transformations,
        "period": data_granularity,
        "optimizationObjective": "minimize-rmse",
        "trainBudgetMilliNodeHours": 8000,
        "staticColumns": static_columns,
        "timeVariantPastOnlyColumns": time_variant_past_only_columns,
        "timeVariantPastAndFutureColumns": time_variant_past_and_future_columns,
        "forecastWindowEnd": forecast_window_end,
    }

    return {
        "display_name": display_name,
        "training_task_definition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_forecasting_1.0.0.yaml",
        # Dict converted to a protobuf Value, as the API field requires.
        "training_task_inputs": to_protobuf_value(task_inputs),
        "input_data_config": {
            "dataset_id": dataset_id,
            # 80/10/10 train/validation/test split.
            "fraction_split": {
                "training_fraction": 0.8,
                "validation_fraction": 0.1,
                "test_fraction": 0.1,
            },
        },
        "model_to_upload": {"display_name": model_display_name},
    }

.sample_configs/process_configs.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ create_batch_prediction_job_custom_image_explain_sample: {}
1919
create_batch_prediction_job_custom_tabular_explain_sample: {}
2020
create_batch_prediction_job_sample: {}
2121
create_batch_prediction_job_tabular_explain_sample: {}
22+
create_batch_prediction_job_tabular_forecasting_sample: {}
2223
create_batch_prediction_job_text_classification_sample: {}
2324
create_batch_prediction_job_text_entity_extraction_sample: {}
2425
create_batch_prediction_job_text_sentiment_analysis_sample: {}
@@ -77,6 +78,7 @@ create_training_pipeline_image_object_detection_sample:
7778
training_task_inputs_dict: trainingjob.definition.AutoMlImageObjectDetectionInputs
7879
create_training_pipeline_sample: {}
7980
create_training_pipeline_tabular_classification_sample: {}
81+
create_training_pipeline_tabular_forecasting_sample: {}
8082
create_training_pipeline_tabular_regression_sample: {}
8183
create_training_pipeline_text_classification_sample:
8284
schema_types:
@@ -168,6 +170,7 @@ get_model_evaluation_sample:
168170
- model_explanation
169171
get_model_evaluation_slice_sample: {}
170172
get_model_evaluation_tabular_classification_sample: {}
173+
get_model_evaluation_tabular_forecasting_sample: {}
171174
get_model_evaluation_tabular_regression_sample: {}
172175
get_model_evaluation_text_classification_sample:
173176
skip:
@@ -232,6 +235,7 @@ list_endpoints_sample: {}
232235
list_hyperparameter_tuning_jobs_sample: {}
233236
list_model_evaluation_slices_sample: {}
234237
list_model_evaluations_sample: {}
238+
list_model_evaluations_tabular_forecasting_sample: {}
235239
list_models_sample: {}
236240
list_specialist_pools_sample: {}
237241
list_training_pipelines_sample: {}
@@ -274,6 +278,7 @@ predict_tabular_classification_sample:
274278
comments:
275279
predictions: See gs://google-cloud-aiplatform/schema/predict/prediction/tables_classification.yaml
276280
for the format of the predictions.
281+
predict_tabular_forecasting_sample: {}
277282
predict_tabular_regression_sample:
278283
api_endpoint: us-central1-prediction-aiplatform.googleapis.com
279284
max_depth: 1

.sample_configs/variants.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ create_batch_prediction_job:
2222
- custom_image_explain
2323
- custom_tabular_explain
2424
- tabular_explain
25+
- tabular_forecasting
2526
- text_classification
2627
- text_entity_extraction
2728
- text_sentiment_analysis
@@ -59,6 +60,7 @@ create_training_pipeline:
5960
- image_classification
6061
- image_object_detection
6162
- tabular_classification
63+
- tabular_forecasting
6264
- tabular_regression
6365
- text_classification
6466
- text_entity_extraction
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# [START aiplatform_create_training_pipeline_tabular_forecasting_sample]
16+
from google.cloud import aiplatform
17+
from google.protobuf import json_format
18+
from google.protobuf.struct_pb2 import Value
19+
20+
21+
def create_training_pipeline_tabular_forecasting_sample(
    project: str,
    display_name: str,
    dataset_id: str,
    model_display_name: str,
    target_column: str,
    time_series_identifier_column: str,
    time_column: str,
    static_columns: str,
    time_variant_past_only_columns: str,
    time_variant_past_and_future_columns: str,
    forecast_window_end: int,
    location: str = "us-central1",
    api_endpoint: str = "us-central1-aiplatform.googleapis.com",
):
    """Create an AutoML tabular-forecasting TrainingPipeline and print the response."""
    # Initialize client that will be used to create and send requests.
    # This client only needs to be created once, and can be reused for
    # multiple requests.
    client = aiplatform.gapic.PipelineServiceClient(
        client_options={"api_endpoint": api_endpoint}
    )

    # Columns used for training; "auto" lets the service infer each data type.
    transformations = [
        {"auto": {"column_name": column}}
        for column in (
            "date",
            "state_name",
            "county_fips_code",
            "confirmed_cases",
            "deaths",
        )
    ]

    # The inputs must be formatted according to the training_task_definition
    # yaml file referenced below.
    task_inputs = {
        # required inputs
        "targetColumn": target_column,
        "timeSeriesIdentifierColumn": time_series_identifier_column,
        "timeColumn": time_column,
        "transformations": transformations,
        # Forecast at a one-day granularity.
        "period": {"unit": "day", "quantity": 1},
        "optimizationObjective": "minimize-rmse",
        "trainBudgetMilliNodeHours": 8000,
        "staticColumns": static_columns,
        "timeVariantPastOnlyColumns": time_variant_past_only_columns,
        "timeVariantPastAndFutureColumns": time_variant_past_and_future_columns,
        "forecastWindowEnd": forecast_window_end,
    }

    pipeline_spec = {
        "display_name": display_name,
        "training_task_definition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_forecasting_1.0.0.yaml",
        # Dict converted to a protobuf Value, as the API field requires.
        "training_task_inputs": json_format.ParseDict(task_inputs, Value()),
        "input_data_config": {
            "dataset_id": dataset_id,
            # 80/10/10 train/validation/test split.
            "fraction_split": {
                "training_fraction": 0.8,
                "validation_fraction": 0.1,
                "test_fraction": 0.1,
            },
        },
        "model_to_upload": {"display_name": model_display_name},
    }

    response = client.create_training_pipeline(
        parent=f"projects/{project}/locations/{location}",
        training_pipeline=pipeline_spec,
    )
    print("response:", response)
88+
89+
90+
# [END aiplatform_create_training_pipeline_tabular_forecasting_sample]
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
from uuid import uuid4
17+
18+
from google.cloud import aiplatform
19+
import pytest
20+
21+
import cancel_training_pipeline_sample
22+
import create_training_pipeline_tabular_forecasting_sample
23+
import delete_training_pipeline_sample
24+
import helpers
25+
26+
# Project comes from the build-specific env var so test runs are isolated.
PROJECT_ID = os.getenv("BUILD_SPECIFIC_GCLOUD_PROJECT")
# Pre-existing tabular dataset used by the sample under test.
DATASET_ID = "3003302817130610688"  # COVID Dataset
# Unique per-run display name so concurrent test runs do not collide.
DISPLAY_NAME = f"temp_create_training_pipeline_test_{uuid4()}"
TARGET_COLUMN = "deaths"
# NOTE(review): PREDICTION_TYPE is not referenced elsewhere in this view —
# confirm it is unused before removing.
PREDICTION_TYPE = "forecasting"
32+
33+
@pytest.fixture
def shared_state():
    """Yield a fresh dict for passing state from a test to its teardown."""
    yield {}
37+
38+
39+
@pytest.fixture(scope="function", autouse=True)
def teardown(shared_state):
    """Auto-use fixture: after each test, cancel, await, and delete the
    training pipeline whose resource name the test stored in
    ``shared_state["training_pipeline_name"]``.

    NOTE(review): assumes every test sets that key before finishing;
    a test that fails earlier would make this raise KeyError — confirm.
    """
    yield

    # Resource names look like projects/.../trainingPipelines/<id>;
    # the last path segment is the pipeline id.
    training_pipeline_id = shared_state["training_pipeline_name"].split("/")[-1]

    # Stop the training pipeline
    cancel_training_pipeline_sample.cancel_training_pipeline_sample(
        project=PROJECT_ID, training_pipeline_id=training_pipeline_id
    )

    client_options = {"api_endpoint": "us-central1-aiplatform.googleapis.com"}
    pipeline_client = aiplatform.gapic.PipelineServiceClient(
        client_options=client_options
    )

    # Waiting for training pipeline to be in CANCELLED state
    # (cancellation above is asynchronous; presumably deletion requires a
    # terminal state — verify against the helper's implementation).
    helpers.wait_for_job_state(
        get_job_method=pipeline_client.get_training_pipeline,
        name=shared_state["training_pipeline_name"],
    )

    # Delete the training pipeline
    delete_training_pipeline_sample.delete_training_pipeline_sample(
        project=PROJECT_ID, training_pipeline_id=training_pipeline_id
    )
66+
67+
def test_ucaip_generated_create_training_pipeline_sample(capsys, shared_state):
    """Smoke-test the forecasting sample: run it against the pre-existing
    COVID dataset and check that it prints the API response."""

    create_training_pipeline_tabular_forecasting_sample.create_training_pipeline_tabular_forecasting_sample(
        project=PROJECT_ID,
        display_name=DISPLAY_NAME,
        dataset_id=DATASET_ID,
        model_display_name="permanent_tabular_forecasting_model",
        target_column=TARGET_COLUMN,
        time_series_identifier_column="county",
        time_column="date",
        static_columns=["state_name"],
        time_variant_past_only_columns=["deaths"],
        time_variant_past_and_future_columns=["date"],
        forecast_window_end=10,
    )

    # The sample prints 'response:' followed by the created pipeline.
    out, _ = capsys.readouterr()
    assert "response:" in out

    # Save resource name of the newly created training pipeline
    # so the autouse teardown fixture can cancel and delete it.
    shared_state["training_pipeline_name"] = helpers.get_name(out)

0 commit comments

Comments
 (0)