Skip to content

Commit 7ae28b8

Browse files
authored
feat: expose env var in cust training class run func args (#366)
1 parent d50d26d commit 7ae28b8

File tree

2 files changed

+142
-0
lines changed

2 files changed

+142
-0
lines changed

google/cloud/aiplatform/training_jobs.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1805,6 +1805,7 @@ def run(
18051805
service_account: Optional[str] = None,
18061806
bigquery_destination: Optional[str] = None,
18071807
args: Optional[List[Union[str, float, int]]] = None,
1808+
environment_variables: Optional[Dict[str, str]] = None,
18081809
replica_count: int = 0,
18091810
machine_type: str = "n1-standard-4",
18101811
accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
@@ -1880,6 +1881,13 @@ def run(
18801881
base_output_dir (str):
18811882
GCS output directory of job. If not provided a
18821883
timestamped directory in the staging directory will be used.
1884+
1885+
AI Platform sets the following environment variables when it runs your training code:
1886+
1887+
- AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts, i.e. <base_output_dir>/model/
1888+
- AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints, i.e. <base_output_dir>/checkpoints/
1889+
- AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard logs, i.e. <base_output_dir>/logs/
1890+
18831891
service_account (str):
18841892
Specifies the service account for workload run-as account.
18851893
Users submitting jobs must have act-as permission on this run-as account.
@@ -1900,6 +1908,16 @@ def run(
19001908
- AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test"
19011909
args (List[Unions[str, int, float]]):
19021910
Command line arguments to be passed to the Python script.
1911+
environment_variables (Dict[str, str]):
1912+
Environment variables to be passed to the container.
1913+
Should be a dictionary where keys are environment variable names
1914+
and values are environment variable values for those names.
1915+
At most 10 environment variables can be specified.
1916+
The Name of the environment variable must be unique.
1917+
1918+
environment_variables = {
1919+
'MY_KEY': 'MY_VALUE'
1920+
}
19031921
replica_count (int):
19041922
The number of worker replicas. If replica count = 1 then one chief
19051923
replica will be provisioned. If replica_count > 1 the remainder will be
@@ -1960,6 +1978,7 @@ def run(
19601978
worker_pool_specs=worker_pool_specs,
19611979
managed_model=managed_model,
19621980
args=args,
1981+
environment_variables=environment_variables,
19631982
base_output_dir=base_output_dir,
19641983
service_account=service_account,
19651984
bigquery_destination=bigquery_destination,
@@ -1986,6 +2005,7 @@ def _run(
19862005
worker_pool_specs: _DistributedTrainingSpec,
19872006
managed_model: Optional[gca_model.Model] = None,
19882007
args: Optional[List[Union[str, float, int]]] = None,
2008+
environment_variables: Optional[Dict[str, str]] = None,
19892009
base_output_dir: Optional[str] = None,
19902010
service_account: Optional[str] = None,
19912011
bigquery_destination: Optional[str] = None,
@@ -2018,9 +2038,26 @@ def _run(
20182038
Model proto if this script produces a Managed Model.
20192039
args (List[Unions[str, int, float]]):
20202040
Command line arguments to be passed to the Python script.
2041+
environment_variables (Dict[str, str]):
2042+
Environment variables to be passed to the container.
2043+
Should be a dictionary where keys are environment variable names
2044+
and values are environment variable values for those names.
2045+
At most 10 environment variables can be specified.
2046+
The Name of the environment variable must be unique.
2047+
2048+
environment_variables = {
2049+
'MY_KEY': 'MY_VALUE'
2050+
}
20212051
base_output_dir (str):
20222052
GCS output directory of job. If not provided a
20232053
timestamped directory in the staging directory will be used.
2054+
2055+
AI Platform sets the following environment variables when it runs your training code:
2056+
2057+
- AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts, i.e. <base_output_dir>/model/
2058+
- AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints, i.e. <base_output_dir>/checkpoints/
2059+
- AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard logs, i.e. <base_output_dir>/logs/
2060+
20242061
service_account (str):
20252062
Specifies the service account for workload run-as account.
20262063
Users submitting jobs must have act-as permission on this run-as account.
@@ -2083,6 +2120,9 @@ def _run(
20832120
if args:
20842121
spec["pythonPackageSpec"]["args"] = args
20852122

2123+
if environment_variables:
2124+
spec["pythonPackageSpec"]["env"] = environment_variables
2125+
20862126
(
20872127
training_task_inputs,
20882128
base_output_dir,
@@ -2334,6 +2374,7 @@ def run(
23342374
service_account: Optional[str] = None,
23352375
bigquery_destination: Optional[str] = None,
23362376
args: Optional[List[Union[str, float, int]]] = None,
2377+
environment_variables: Optional[Dict[str, str]] = None,
23372378
replica_count: int = 0,
23382379
machine_type: str = "n1-standard-4",
23392380
accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
@@ -2402,6 +2443,13 @@ def run(
24022443
base_output_dir (str):
24032444
GCS output directory of job. If not provided a
24042445
timestamped directory in the staging directory will be used.
2446+
2447+
AI Platform sets the following environment variables when it runs your training code:
2448+
2449+
- AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts, i.e. <base_output_dir>/model/
2450+
- AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints, i.e. <base_output_dir>/checkpoints/
2451+
- AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard logs, i.e. <base_output_dir>/logs/
2452+
24052453
service_account (str):
24062454
Specifies the service account for workload run-as account.
24072455
Users submitting jobs must have act-as permission on this run-as account.
@@ -2422,6 +2470,16 @@ def run(
24222470
- AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test"
24232471
args (List[Unions[str, int, float]]):
24242472
Command line arguments to be passed to the Python script.
2473+
environment_variables (Dict[str, str]):
2474+
Environment variables to be passed to the container.
2475+
Should be a dictionary where keys are environment variable names
2476+
and values are environment variable values for those names.
2477+
At most 10 environment variables can be specified.
2478+
The Name of the environment variable must be unique.
2479+
2480+
environment_variables = {
2481+
'MY_KEY': 'MY_VALUE'
2482+
}
24252483
replica_count (int):
24262484
The number of worker replicas. If replica count = 1 then one chief
24272485
replica will be provisioned. If replica_count > 1 the remainder will be
@@ -2481,6 +2539,7 @@ def run(
24812539
worker_pool_specs=worker_pool_specs,
24822540
managed_model=managed_model,
24832541
args=args,
2542+
environment_variables=environment_variables,
24842543
base_output_dir=base_output_dir,
24852544
service_account=service_account,
24862545
bigquery_destination=bigquery_destination,
@@ -2506,6 +2565,7 @@ def _run(
25062565
worker_pool_specs: _DistributedTrainingSpec,
25072566
managed_model: Optional[gca_model.Model] = None,
25082567
args: Optional[List[Union[str, float, int]]] = None,
2568+
environment_variables: Optional[Dict[str, str]] = None,
25092569
base_output_dir: Optional[str] = None,
25102570
service_account: Optional[str] = None,
25112571
bigquery_destination: Optional[str] = None,
@@ -2535,9 +2595,26 @@ def _run(
25352595
Model proto if this script produces a Managed Model.
25362596
args (List[Unions[str, int, float]]):
25372597
Command line arguments to be passed to the Python script.
2598+
environment_variables (Dict[str, str]):
2599+
Environment variables to be passed to the container.
2600+
Should be a dictionary where keys are environment variable names
2601+
and values are environment variable values for those names.
2602+
At most 10 environment variables can be specified.
2603+
The Name of the environment variable must be unique.
2604+
2605+
environment_variables = {
2606+
'MY_KEY': 'MY_VALUE'
2607+
}
25382608
base_output_dir (str):
25392609
GCS output directory of job. If not provided a
25402610
timestamped directory in the staging directory will be used.
2611+
2612+
AI Platform sets the following environment variables when it runs your training code:
2613+
2614+
- AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts, i.e. <base_output_dir>/model/
2615+
- AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints, i.e. <base_output_dir>/checkpoints/
2616+
- AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard logs, i.e. <base_output_dir>/logs/
2617+
25412618
service_account (str):
25422619
Specifies the service account for workload run-as account.
25432620
Users submitting jobs must have act-as permission on this run-as account.
@@ -2593,6 +2670,9 @@ def _run(
25932670
if args:
25942671
spec["containerSpec"]["args"] = args
25952672

2673+
if environment_variables:
2674+
spec["containerSpec"]["env"] = environment_variables
2675+
25962676
(
25972677
training_task_inputs,
25982678
base_output_dir,
@@ -3625,6 +3705,7 @@ def run(
36253705
service_account: Optional[str] = None,
36263706
bigquery_destination: Optional[str] = None,
36273707
args: Optional[List[Union[str, float, int]]] = None,
3708+
environment_variables: Optional[Dict[str, str]] = None,
36283709
replica_count: int = 0,
36293710
machine_type: str = "n1-standard-4",
36303711
accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
@@ -3693,6 +3774,13 @@ def run(
36933774
base_output_dir (str):
36943775
GCS output directory of job. If not provided a
36953776
timestamped directory in the staging directory will be used.
3777+
3778+
AI Platform sets the following environment variables when it runs your training code:
3779+
3780+
- AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts, i.e. <base_output_dir>/model/
3781+
- AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints, i.e. <base_output_dir>/checkpoints/
3782+
- AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard logs, i.e. <base_output_dir>/logs/
3783+
36963784
service_account (str):
36973785
Specifies the service account for workload run-as account.
36983786
Users submitting jobs must have act-as permission on this run-as account.
@@ -3713,6 +3801,16 @@ def run(
37133801
- AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test"
37143802
args (List[Unions[str, int, float]]):
37153803
Command line arguments to be passed to the Python script.
3804+
environment_variables (Dict[str, str]):
3805+
Environment variables to be passed to the container.
3806+
Should be a dictionary where keys are environment variable names
3807+
and values are environment variable values for those names.
3808+
At most 10 environment variables can be specified.
3809+
The Name of the environment variable must be unique.
3810+
3811+
environment_variables = {
3812+
'MY_KEY': 'MY_VALUE'
3813+
}
37163814
replica_count (int):
37173815
The number of worker replicas. If replica count = 1 then one chief
37183816
replica will be provisioned. If replica_count > 1 the remainder will be
@@ -3767,6 +3865,7 @@ def run(
37673865
worker_pool_specs=worker_pool_specs,
37683866
managed_model=managed_model,
37693867
args=args,
3868+
environment_variables=environment_variables,
37703869
base_output_dir=base_output_dir,
37713870
service_account=service_account,
37723871
training_fraction_split=training_fraction_split,
@@ -3792,6 +3891,7 @@ def _run(
37923891
worker_pool_specs: _DistributedTrainingSpec,
37933892
managed_model: Optional[gca_model.Model] = None,
37943893
args: Optional[List[Union[str, float, int]]] = None,
3894+
environment_variables: Optional[Dict[str, str]] = None,
37953895
base_output_dir: Optional[str] = None,
37963896
service_account: Optional[str] = None,
37973897
training_fraction_split: float = 0.8,
@@ -3822,9 +3922,26 @@ def _run(
38223922
Model proto if this script produces a Managed Model.
38233923
args (List[Unions[str, int, float]]):
38243924
Command line arguments to be passed to the Python script.
3925+
environment_variables (Dict[str, str]):
3926+
Environment variables to be passed to the container.
3927+
Should be a dictionary where keys are environment variable names
3928+
and values are environment variable values for those names.
3929+
At most 10 environment variables can be specified.
3930+
The Name of the environment variable must be unique.
3931+
3932+
environment_variables = {
3933+
'MY_KEY': 'MY_VALUE'
3934+
}
38253935
base_output_dir (str):
38263936
GCS output directory of job. If not provided a
38273937
timestamped directory in the staging directory will be used.
3938+
3939+
AI Platform sets the following environment variables when it runs your training code:
3940+
3941+
- AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts, i.e. <base_output_dir>/model/
3942+
- AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints, i.e. <base_output_dir>/checkpoints/
3943+
- AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard logs, i.e. <base_output_dir>/logs/
3944+
38283945
service_account (str):
38293946
Specifies the service account for workload run-as account.
38303947
Users submitting jobs must have act-as permission on this run-as account.
@@ -3866,6 +3983,9 @@ def _run(
38663983
if args:
38673984
spec["pythonPackageSpec"]["args"] = args
38683985

3986+
if environment_variables:
3987+
spec["pythonPackageSpec"]["env"] = environment_variables
3988+
38693989
(
38703990
training_task_inputs,
38713991
base_output_dir,

0 commit comments

Comments
 (0)