Skip to content

Commit 6bc4c61

Browse files
authored
fix: env formatiing (#379)
1 parent 8945865 commit 6bc4c61

File tree

2 files changed

+37
-10
lines changed

2 files changed

+37
-10
lines changed

google/cloud/aiplatform/training_jobs.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2121,7 +2121,10 @@ def _run(
21212121
spec["pythonPackageSpec"]["args"] = args
21222122

21232123
if environment_variables:
2124-
spec["pythonPackageSpec"]["env"] = environment_variables
2124+
spec["pythonPackageSpec"]["env"] = [
2125+
{"name": key, "value": value}
2126+
for key, value in environment_variables.items()
2127+
]
21252128

21262129
(
21272130
training_task_inputs,
@@ -2671,7 +2674,10 @@ def _run(
26712674
spec["containerSpec"]["args"] = args
26722675

26732676
if environment_variables:
2674-
spec["containerSpec"]["env"] = environment_variables
2677+
spec["containerSpec"]["env"] = [
2678+
{"name": key, "value": value}
2679+
for key, value in environment_variables.items()
2680+
]
26752681

26762682
(
26772683
training_task_inputs,
@@ -3734,7 +3740,7 @@ def run(
37343740
Args:
37353741
dataset (Union[datasets.ImageDataset,datasets.TabularDataset,datasets.TextDataset,datasets.VideoDataset,]):
37363742
AI Platform to fit this training against. Custom training script should
3737-
retrieve datasets through passed in environement variables uris:
3743+
retrieve datasets through passed in environment variables uris:
37383744
37393745
os.environ["AIP_TRAINING_DATA_URI"]
37403746
os.environ["AIP_VALIDATION_DATA_URI"]
@@ -3984,7 +3990,10 @@ def _run(
39843990
spec["pythonPackageSpec"]["args"] = args
39853991

39863992
if environment_variables:
3987-
spec["pythonPackageSpec"]["env"] = environment_variables
3993+
spec["pythonPackageSpec"]["env"] = [
3994+
{"name": key, "value": value}
3995+
for key, value in environment_variables.items()
3996+
]
39883997

39893998
(
39903999
training_task_inputs,

tests/unit/aiplatform/test_training_jobs.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -622,7 +622,10 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
622622
)
623623

624624
true_args = _TEST_RUN_ARGS
625-
true_env = _TEST_ENVIRONMENT_VARIABLES
625+
true_env = [
626+
{"name": key, "value": value}
627+
for key, value in _TEST_ENVIRONMENT_VARIABLES.items()
628+
]
626629

627630
true_worker_pool_spec = {
628631
"replicaCount": _TEST_REPLICA_COUNT,
@@ -777,7 +780,10 @@ def test_run_call_pipeline_service_create_with_bigquery_destination(
777780
model_from_job.wait()
778781

779782
true_args = _TEST_RUN_ARGS
780-
true_env = _TEST_ENVIRONMENT_VARIABLES
783+
true_env = [
784+
{"name": key, "value": value}
785+
for key, value in _TEST_ENVIRONMENT_VARIABLES.items()
786+
]
781787

782788
true_worker_pool_spec = {
783789
"replicaCount": _TEST_REPLICA_COUNT,
@@ -1049,7 +1055,10 @@ def test_run_call_pipeline_service_create_with_no_dataset(
10491055
)
10501056

10511057
true_args = _TEST_RUN_ARGS
1052-
true_env = _TEST_ENVIRONMENT_VARIABLES
1058+
true_env = [
1059+
{"name": key, "value": value}
1060+
for key, value in _TEST_ENVIRONMENT_VARIABLES.items()
1061+
]
10531062

10541063
true_worker_pool_spec = {
10551064
"replicaCount": _TEST_REPLICA_COUNT,
@@ -1297,7 +1306,10 @@ def test_run_call_pipeline_service_create_distributed_training(
12971306
)
12981307

12991308
true_args = _TEST_RUN_ARGS
1300-
true_env = _TEST_ENVIRONMENT_VARIABLES
1309+
true_env = [
1310+
{"name": key, "value": value}
1311+
for key, value in _TEST_ENVIRONMENT_VARIABLES.items()
1312+
]
13011313

13021314
true_worker_pool_spec = [
13031315
{
@@ -1763,7 +1775,10 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
17631775
model_from_job.wait()
17641776

17651777
true_args = _TEST_RUN_ARGS
1766-
true_env = _TEST_ENVIRONMENT_VARIABLES
1778+
true_env = [
1779+
{"name": key, "value": value}
1780+
for key, value in _TEST_ENVIRONMENT_VARIABLES.items()
1781+
]
17671782

17681783
true_worker_pool_spec = {
17691784
"replicaCount": _TEST_REPLICA_COUNT,
@@ -2972,7 +2987,10 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
29722987
model_from_job.wait()
29732988

29742989
true_args = _TEST_RUN_ARGS
2975-
true_env = _TEST_ENVIRONMENT_VARIABLES
2990+
true_env = [
2991+
{"name": key, "value": value}
2992+
for key, value in _TEST_ENVIRONMENT_VARIABLES.items()
2993+
]
29762994

29772995
true_worker_pool_spec = {
29782996
"replicaCount": _TEST_REPLICA_COUNT,

0 commit comments

Comments
 (0)