Skip to content

Commit fda942f

Browse files
fix: add param for multi-label per user's feedback (#887)
* fix: add param for multi-label per user's feedback * fix: indentation * test: update assert for new params * lint: remove trailing whitespace
1 parent 67fa1f1 commit fda942f

2 files changed

+9
-2
lines changed

samples/model-builder/create_training_pipeline_image_classification_sample.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def create_training_pipeline_image_classification_sample(
2424
display_name: str,
2525
dataset_id: int,
2626
model_display_name: Optional[str] = None,
27+
multi_label: bool = False,
2728
training_fraction_split: float = 0.8,
2829
validation_fraction_split: float = 0.1,
2930
test_fraction_split: float = 0.1,
@@ -33,7 +34,11 @@ def create_training_pipeline_image_classification_sample(
3334
):
3435
aiplatform.init(project=project, location=location)
3536

36-
job = aiplatform.AutoMLImageTrainingJob(display_name=display_name)
37+
job = aiplatform.AutoMLImageTrainingJob(
38+
display_name=display_name,
39+
prediction_type='classification',
40+
multi_label=multi_label
41+
)
3742

3843
my_image_ds = aiplatform.ImageDataset(dataset_id)
3944

samples/model-builder/create_training_pipeline_image_classification_sample_test.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@ def test_create_training_pipeline_image_classification_sample(
4444
project=constants.PROJECT, location=constants.LOCATION
4545
)
4646
mock_get_automl_image_training_job.assert_called_once_with(
47-
display_name=constants.DISPLAY_NAME
47+
display_name=constants.DISPLAY_NAME,
48+
multi_label=False,
49+
prediction_type='classification'
4850
)
4951
mock_run_automl_image_training_job.assert_called_once_with(
5052
dataset=mock_image_dataset,

0 commit comments

Comments
 (0)