Skip to content

Commit 1a13577

Browse files
ivanmkcsasha-gitg
andauthored
feat: add filter and timestamp splits (#627)
* Fixed splits * Fixed docstrings * Fix test bug * Ran linter * Fixed FractionSplit and AutoMLVideo FilterSplit issues * Added warning for incomplete filter splits * Fixed AutoMLVideo tests * Fixed type * Moved annotation_schema_uri * Tweaked docstrings Co-authored-by: sasha-gitg <[email protected]>
1 parent 74f81e6 commit 1a13577

7 files changed

+2115
-556
lines changed

google/cloud/aiplatform/training_jobs.py

Lines changed: 943 additions & 238 deletions
Large diffs are not rendered by default.

tests/unit/aiplatform/test_automl_forecasting_training_jobs.py

Lines changed: 2 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -103,14 +103,11 @@
103103
_TEST_DATASET_NAME = "test-dataset-name"
104104

105105
_TEST_MODEL_DISPLAY_NAME = "model-display-name"
106+
106107
_TEST_LABELS = {"key": "value"}
107108
_TEST_MODEL_LABELS = {"model_key": "model_value"}
108-
_TEST_TRAINING_FRACTION_SPLIT = 0.8
109-
_TEST_VALIDATION_FRACTION_SPLIT = 0.1
110-
_TEST_TEST_FRACTION_SPLIT = 0.1
111-
_TEST_PREDEFINED_SPLIT_COLUMN_NAME = "split"
112109

113-
_TEST_OUTPUT_PYTHON_PACKAGE_PATH = "gs://test/ouput/python/trainer.tar.gz"
110+
_TEST_PREDEFINED_SPLIT_COLUMN_NAME = "split"
114111

115112
_TEST_MODEL_NAME = "projects/my-project/locations/us-central1/models/12345"
116113

@@ -261,18 +258,11 @@ def test_run_call_pipeline_service_create(
261258
if not sync:
262259
model_from_job.wait()
263260

264-
true_fraction_split = gca_training_pipeline.FractionSplit(
265-
training_fraction=_TEST_TRAINING_FRACTION_SPLIT,
266-
validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT,
267-
test_fraction=_TEST_TEST_FRACTION_SPLIT,
268-
)
269-
270261
true_managed_model = gca_model.Model(
271262
display_name=_TEST_MODEL_DISPLAY_NAME, labels=_TEST_MODEL_LABELS
272263
)
273264

274265
true_input_data_config = gca_training_pipeline.InputDataConfig(
275-
fraction_split=true_fraction_split,
276266
predefined_split=gca_training_pipeline.PredefinedSplit(
277267
key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME
278268
),
@@ -348,19 +338,12 @@ def test_run_call_pipeline_if_no_model_display_name_nor_model_labels(
348338
if not sync:
349339
model_from_job.wait()
350340

351-
true_fraction_split = gca_training_pipeline.FractionSplit(
352-
training_fraction=_TEST_TRAINING_FRACTION_SPLIT,
353-
validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT,
354-
test_fraction=_TEST_TEST_FRACTION_SPLIT,
355-
)
356-
357341
# Test that if defaults to the job display name
358342
true_managed_model = gca_model.Model(
359343
display_name=_TEST_DISPLAY_NAME, labels=_TEST_LABELS,
360344
)
361345

362346
true_input_data_config = gca_training_pipeline.InputDataConfig(
363-
fraction_split=true_fraction_split,
364347
dataset_id=mock_dataset_time_series.name,
365348
)
366349

@@ -422,17 +405,10 @@ def test_run_call_pipeline_if_set_additional_experiments(
422405
if not sync:
423406
model_from_job.wait()
424407

425-
true_fraction_split = gca_training_pipeline.FractionSplit(
426-
training_fraction=_TEST_TRAINING_FRACTION_SPLIT,
427-
validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT,
428-
test_fraction=_TEST_TEST_FRACTION_SPLIT,
429-
)
430-
431408
# Test that if defaults to the job display name
432409
true_managed_model = gca_model.Model(display_name=_TEST_DISPLAY_NAME)
433410

434411
true_input_data_config = gca_training_pipeline.InputDataConfig(
435-
fraction_split=true_fraction_split,
436412
dataset_id=mock_dataset_time_series.name,
437413
)
438414

0 commit comments

Comments
 (0)