Skip to content

Commit 2f138d1

Browse files
authored
feat: expose base_output_dir for custom job (#586)
1 parent 2a6b0a3 commit 2f138d1

File tree

3 files changed

+57
-7
lines changed

3 files changed

+57
-7
lines changed

google/cloud/aiplatform/jobs.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
study as gca_study_compat,
6060
)
6161

62+
6263
_LOGGER = base.Logger(__name__)
6364

6465
_JOB_COMPLETE_STATES = (
@@ -930,6 +931,7 @@ def __init__(
930931
self,
931932
display_name: str,
932933
worker_pool_specs: Union[List[Dict], List[aiplatform.gapic.WorkerPoolSpec]],
934+
base_output_dir: Optional[str] = None,
933935
project: Optional[str] = None,
934936
location: Optional[str] = None,
935937
credentials: Optional[auth_credentials.Credentials] = None,
@@ -977,6 +979,9 @@ def __init__(
977979
worker_pool_specs (Union[List[Dict], List[aiplatform.gapic.WorkerPoolSpec]]):
978980
Required. The spec of the worker pools including machine type and Docker image.
979981
Can provided as a list of dictionaries or list of WorkerPoolSpec proto messages.
982+
base_output_dir (str):
983+
Optional. GCS output directory of job. If not provided a
984+
timestamped directory in the staging directory will be used.
980985
project (str):
981986
Optional.Project to run the custom job in. Overrides project set in aiplatform.init.
982987
location (str):
@@ -1008,12 +1013,17 @@ def __init__(
10081013
"should be set using aiplatform.init(staging_bucket='gs://my-bucket')"
10091014
)
10101015

1016+
# default directory if not given
1017+
base_output_dir = base_output_dir or utils._timestamped_gcs_dir(
1018+
staging_bucket, "aiplatform-custom-job"
1019+
)
1020+
10111021
self._gca_resource = gca_custom_job_compat.CustomJob(
10121022
display_name=display_name,
10131023
job_spec=gca_custom_job_compat.CustomJobSpec(
10141024
worker_pool_specs=worker_pool_specs,
10151025
base_output_directory=gca_io_compat.GcsDestination(
1016-
output_uri_prefix=staging_bucket
1026+
output_uri_prefix=base_output_dir
10171027
),
10181028
),
10191029
encryption_spec=initializer.global_config.get_encryption_spec(
@@ -1049,6 +1059,7 @@ def from_local_script(
10491059
machine_type: str = "n1-standard-4",
10501060
accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
10511061
accelerator_count: int = 0,
1062+
base_output_dir: Optional[str] = None,
10521063
project: Optional[str] = None,
10531064
location: Optional[str] = None,
10541065
credentials: Optional[auth_credentials.Credentials] = None,
@@ -1105,6 +1116,9 @@ def from_local_script(
11051116
NVIDIA_TESLA_T4
11061117
accelerator_count (int):
11071118
Optional. The number of accelerators to attach to a worker replica.
1119+
base_output_dir (str):
1120+
Optional. GCS output directory of job. If not provided a
1121+
timestamped directory in the staging directory will be used.
11081122
project (str):
11091123
Optional. Project to run the custom job in. Overrides project set in aiplatform.init.
11101124
location (str):
@@ -1170,6 +1184,7 @@ def from_local_script(
11701184
return cls(
11711185
display_name=display_name,
11721186
worker_pool_specs=worker_pool_specs,
1187+
base_output_dir=base_output_dir,
11731188
project=project,
11741189
location=location,
11751190
credentials=credentials,

tests/unit/aiplatform/test_custom_job.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
]
7272

7373
_TEST_STAGING_BUCKET = "gs://test-staging-bucket"
74+
_TEST_BASE_OUTPUT_DIR = f"{_TEST_STAGING_BUCKET}/{_TEST_DISPLAY_NAME}"
7475

7576
# CMEK encryption
7677
_TEST_DEFAULT_ENCRYPTION_KEY_NAME = "key_default"
@@ -91,7 +92,7 @@
9192
job_spec=gca_custom_job_compat.CustomJobSpec(
9293
worker_pool_specs=_TEST_WORKER_POOL_SPEC,
9394
base_output_directory=gca_io_compat.GcsDestination(
94-
output_uri_prefix=_TEST_STAGING_BUCKET
95+
output_uri_prefix=_TEST_BASE_OUTPUT_DIR
9596
),
9697
scheduling=gca_custom_job_compat.Scheduling(
9798
timeout=duration_pb2.Duration(seconds=_TEST_TIMEOUT),
@@ -224,7 +225,9 @@ def test_create_custom_job(self, create_custom_job_mock, get_custom_job_mock, sy
224225
)
225226

226227
job = aiplatform.CustomJob(
227-
display_name=_TEST_DISPLAY_NAME, worker_pool_specs=_TEST_WORKER_POOL_SPEC
228+
display_name=_TEST_DISPLAY_NAME,
229+
worker_pool_specs=_TEST_WORKER_POOL_SPEC,
230+
base_output_dir=_TEST_BASE_OUTPUT_DIR,
228231
)
229232

230233
job.run(
@@ -265,7 +268,9 @@ def test_run_custom_job_with_fail_raises(
265268
)
266269

267270
job = aiplatform.CustomJob(
268-
display_name=_TEST_DISPLAY_NAME, worker_pool_specs=_TEST_WORKER_POOL_SPEC
271+
display_name=_TEST_DISPLAY_NAME,
272+
worker_pool_specs=_TEST_WORKER_POOL_SPEC,
273+
base_output_dir=_TEST_BASE_OUTPUT_DIR,
269274
)
270275

271276
with pytest.raises(RuntimeError) as e:
@@ -306,7 +311,9 @@ def test_run_custom_job_with_fail_at_creation(self):
306311
)
307312

308313
job = aiplatform.CustomJob(
309-
display_name=_TEST_DISPLAY_NAME, worker_pool_specs=_TEST_WORKER_POOL_SPEC
314+
display_name=_TEST_DISPLAY_NAME,
315+
worker_pool_specs=_TEST_WORKER_POOL_SPEC,
316+
base_output_dir=_TEST_BASE_OUTPUT_DIR,
310317
)
311318

312319
job.run(
@@ -342,7 +349,9 @@ def test_custom_job_get_state_raises_without_run(self):
342349
)
343350

344351
job = aiplatform.CustomJob(
345-
display_name=_TEST_DISPLAY_NAME, worker_pool_specs=_TEST_WORKER_POOL_SPEC
352+
display_name=_TEST_DISPLAY_NAME,
353+
worker_pool_specs=_TEST_WORKER_POOL_SPEC,
354+
base_output_dir=_TEST_BASE_OUTPUT_DIR,
346355
)
347356

348357
with pytest.raises(RuntimeError):
@@ -385,6 +394,7 @@ def test_create_from_local_script(
385394
display_name=_TEST_DISPLAY_NAME,
386395
script_path=test_training_jobs._TEST_LOCAL_SCRIPT_FILE_NAME,
387396
container_uri=_TEST_TRAINING_CONTAINER_IMAGE,
397+
base_output_dir=_TEST_BASE_OUTPUT_DIR,
388398
)
389399

390400
job.run(sync=sync)
@@ -428,7 +438,9 @@ def test_create_custom_job_with_tensorboard(
428438
)
429439

430440
job = aiplatform.CustomJob(
431-
display_name=_TEST_DISPLAY_NAME, worker_pool_specs=_TEST_WORKER_POOL_SPEC
441+
display_name=_TEST_DISPLAY_NAME,
442+
worker_pool_specs=_TEST_WORKER_POOL_SPEC,
443+
base_output_dir=_TEST_BASE_OUTPUT_DIR,
432444
)
433445

434446
job.run(
@@ -454,3 +466,20 @@ def test_create_custom_job_with_tensorboard(
454466
assert (
455467
job._gca_resource.state == gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED
456468
)
469+
470+
def test_create_custom_job_without_base_output_dir(self,):
471+
472+
aiplatform.init(
473+
project=_TEST_PROJECT,
474+
location=_TEST_LOCATION,
475+
staging_bucket=_TEST_STAGING_BUCKET,
476+
encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
477+
)
478+
479+
job = aiplatform.CustomJob(
480+
display_name=_TEST_DISPLAY_NAME, worker_pool_specs=_TEST_WORKER_POOL_SPEC,
481+
)
482+
483+
assert job.job_spec.base_output_directory.output_uri_prefix.startswith(
484+
f"{_TEST_STAGING_BUCKET}/aiplatform-custom-job"
485+
)

tests/unit/aiplatform/test_hyperparameter_tuning_job.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
5050

5151
_TEST_STAGING_BUCKET = test_custom_job._TEST_STAGING_BUCKET
52+
_TEST_BASE_OUTPUT_DIR = test_custom_job._TEST_BASE_OUTPUT_DIR
5253

5354
_TEST_HYPERPARAMETERTUNING_JOB_NAME = (
5455
f"{_TEST_PARENT}/hyperparameterTuningJobs/{_TEST_ID}"
@@ -260,6 +261,7 @@ def test_create_hyperparameter_tuning_job(
260261
custom_job = aiplatform.CustomJob(
261262
display_name=test_custom_job._TEST_DISPLAY_NAME,
262263
worker_pool_specs=test_custom_job._TEST_WORKER_POOL_SPEC,
264+
base_output_dir=test_custom_job._TEST_BASE_OUTPUT_DIR,
263265
)
264266

265267
job = aiplatform.HyperparameterTuningJob(
@@ -321,6 +323,7 @@ def test_run_hyperparameter_tuning_job_with_fail_raises(
321323
custom_job = aiplatform.CustomJob(
322324
display_name=test_custom_job._TEST_DISPLAY_NAME,
323325
worker_pool_specs=test_custom_job._TEST_WORKER_POOL_SPEC,
326+
base_output_dir=test_custom_job._TEST_BASE_OUTPUT_DIR,
324327
)
325328

326329
job = aiplatform.HyperparameterTuningJob(
@@ -376,6 +379,7 @@ def test_run_hyperparameter_tuning_job_with_fail_at_creation(self):
376379
custom_job = aiplatform.CustomJob(
377380
display_name=test_custom_job._TEST_DISPLAY_NAME,
378381
worker_pool_specs=test_custom_job._TEST_WORKER_POOL_SPEC,
382+
base_output_dir=test_custom_job._TEST_BASE_OUTPUT_DIR,
379383
)
380384

381385
job = aiplatform.HyperparameterTuningJob(
@@ -440,6 +444,7 @@ def test_hyperparameter_tuning_job_get_state_raises_without_run(self):
440444
custom_job = aiplatform.CustomJob(
441445
display_name=test_custom_job._TEST_DISPLAY_NAME,
442446
worker_pool_specs=test_custom_job._TEST_WORKER_POOL_SPEC,
447+
base_output_dir=test_custom_job._TEST_BASE_OUTPUT_DIR,
443448
)
444449

445450
job = aiplatform.HyperparameterTuningJob(
@@ -497,6 +502,7 @@ def test_create_hyperparameter_tuning_job_with_tensorboard(
497502
custom_job = aiplatform.CustomJob(
498503
display_name=test_custom_job._TEST_DISPLAY_NAME,
499504
worker_pool_specs=test_custom_job._TEST_WORKER_POOL_SPEC,
505+
base_output_dir=test_custom_job._TEST_BASE_OUTPUT_DIR,
500506
)
501507

502508
job = aiplatform.HyperparameterTuningJob(

0 commit comments

Comments
 (0)