Skip to content

Commit 45401c0

Browse files
authored
fix: Add retries when polling during monitoring runs (#786)
1 parent 78879e2 commit 45401c0

20 files changed

+288
-111
lines changed

google/cloud/aiplatform/base.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939

4040
import proto
4141

42+
from google.api_core import retry
4243
from google.api_core import operation
4344
from google.auth import credentials as auth_credentials
4445
from google.cloud.aiplatform import initializer
@@ -48,6 +49,9 @@
4849

4950
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
5051

52+
# This is the default retry callback to be used with get methods.
53+
_DEFAULT_RETRY = retry.Retry()
54+
5155

5256
class Logger:
5357
"""Logging wrapper class with high level helper methods."""
@@ -532,7 +536,9 @@ def _get_gca_resource(self, resource_name: str) -> proto.Message:
532536
location=self.location,
533537
)
534538

535-
return getattr(self.api_client, self._getter_method)(name=resource_name)
539+
return getattr(self.api_client, self._getter_method)(
540+
name=resource_name, retry=_DEFAULT_RETRY
541+
)
536542

537543
def _sync_gca_resource(self):
538544
"""Sync GAPIC service representation of client class resource."""

google/cloud/aiplatform/metadata/resource.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def __init__(
9494
)
9595

9696
self._gca_resource = getattr(self.api_client, self._getter_method)(
97-
name=full_resource_name
97+
name=full_resource_name, retry=base._DEFAULT_RETRY
9898
)
9999

100100
@property

tests/unit/aiplatform/test_automl_forecasting_training_jobs.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from unittest import mock
44

55
from google.cloud import aiplatform
6+
from google.cloud.aiplatform import base
67
from google.cloud.aiplatform import datasets
78
from google.cloud.aiplatform import initializer
89
from google.cloud.aiplatform import schema
@@ -301,7 +302,9 @@ def test_run_call_pipeline_service_create(
301302

302303
assert job._gca_resource is mock_pipeline_service_get.return_value
303304

304-
mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME)
305+
mock_model_service_get.assert_called_once_with(
306+
name=_TEST_MODEL_NAME, retry=base._DEFAULT_RETRY
307+
)
305308

306309
assert model_from_job._gca_resource is mock_model_service_get.return_value
307310

tests/unit/aiplatform/test_automl_image_training_jobs.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from google.protobuf import struct_pb2
77

88
from google.cloud import aiplatform
9-
9+
from google.cloud.aiplatform import base
1010
from google.cloud.aiplatform import datasets
1111
from google.cloud.aiplatform import initializer
1212
from google.cloud.aiplatform import models
@@ -309,7 +309,9 @@ def test_run_call_pipeline_service_create(
309309
training_pipeline=true_training_pipeline,
310310
)
311311

312-
mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME)
312+
mock_model_service_get.assert_called_once_with(
313+
name=_TEST_MODEL_NAME, retry=base._DEFAULT_RETRY
314+
)
313315
assert job._gca_resource is mock_pipeline_service_get.return_value
314316
assert model_from_job._gca_resource is mock_model_service_get.return_value
315317
assert job.get_model()._gca_resource is mock_model_service_get.return_value

tests/unit/aiplatform/test_automl_tabular_training_jobs.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from unittest import mock
44

55
from google.cloud import aiplatform
6-
6+
from google.cloud.aiplatform import base
77
from google.cloud.aiplatform import datasets
88
from google.cloud.aiplatform import initializer
99
from google.cloud.aiplatform import schema
@@ -367,7 +367,9 @@ def test_run_call_pipeline_service_create(
367367

368368
assert job._gca_resource is mock_pipeline_service_get.return_value
369369

370-
mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME)
370+
mock_model_service_get.assert_called_once_with(
371+
name=_TEST_MODEL_NAME, retry=base._DEFAULT_RETRY
372+
)
371373

372374
assert model_from_job._gca_resource is mock_model_service_get.return_value
373375

@@ -446,7 +448,9 @@ def test_run_call_pipeline_service_create_with_export_eval_data_items(
446448

447449
assert job._gca_resource is mock_pipeline_service_get.return_value
448450

449-
mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME)
451+
mock_model_service_get.assert_called_once_with(
452+
name=_TEST_MODEL_NAME, retry=base._DEFAULT_RETRY
453+
)
450454

451455
assert model_from_job._gca_resource is mock_model_service_get.return_value
452456

tests/unit/aiplatform/test_automl_text_training_jobs.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from unittest import mock
44

55
from google.cloud import aiplatform
6-
6+
from google.cloud.aiplatform import base
77
from google.cloud.aiplatform import datasets
88
from google.cloud.aiplatform import initializer
99
from google.cloud.aiplatform import models
@@ -370,7 +370,9 @@ def test_run_call_pipeline_service_create_classification(
370370
training_pipeline=true_training_pipeline,
371371
)
372372

373-
mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME)
373+
mock_model_service_get.assert_called_once_with(
374+
name=_TEST_MODEL_NAME, retry=base._DEFAULT_RETRY
375+
)
374376
assert job._gca_resource is mock_pipeline_service_get.return_value
375377
assert model_from_job._gca_resource is mock_model_service_get.return_value
376378
assert job.get_model()._gca_resource is mock_model_service_get.return_value
@@ -437,7 +439,9 @@ def test_run_call_pipeline_service_create_extraction(
437439
training_pipeline=true_training_pipeline,
438440
)
439441

440-
mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME)
442+
mock_model_service_get.assert_called_once_with(
443+
name=_TEST_MODEL_NAME, retry=base._DEFAULT_RETRY
444+
)
441445
assert job._gca_resource is mock_pipeline_service_get.return_value
442446
assert model_from_job._gca_resource is mock_model_service_get.return_value
443447
assert job.get_model()._gca_resource is mock_model_service_get.return_value
@@ -505,7 +509,9 @@ def test_run_call_pipeline_service_create_sentiment(
505509
training_pipeline=true_training_pipeline,
506510
)
507511

508-
mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME)
512+
mock_model_service_get.assert_called_once_with(
513+
name=_TEST_MODEL_NAME, retry=base._DEFAULT_RETRY
514+
)
509515
assert job._gca_resource is mock_pipeline_service_get.return_value
510516
assert model_from_job._gca_resource is mock_model_service_get.return_value
511517
assert job.get_model()._gca_resource is mock_model_service_get.return_value

tests/unit/aiplatform/test_automl_video_training_jobs.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from google.protobuf import struct_pb2
77

88
from google.cloud import aiplatform
9-
9+
from google.cloud.aiplatform import base
1010
from google.cloud.aiplatform import datasets
1111
from google.cloud.aiplatform import initializer
1212
from google.cloud.aiplatform import models
@@ -271,7 +271,9 @@ def test_init_aiplatform_with_encryption_key_name_and_create_training_job(
271271
training_pipeline=true_training_pipeline,
272272
)
273273

274-
mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME)
274+
mock_model_service_get.assert_called_once_with(
275+
name=_TEST_MODEL_NAME, retry=base._DEFAULT_RETRY
276+
)
275277
assert job._gca_resource is mock_pipeline_service_get.return_value
276278
assert model_from_job._gca_resource is mock_model_service_get.return_value
277279
assert job.get_model()._gca_resource is mock_model_service_get.return_value
@@ -538,7 +540,9 @@ def test_run_call_pipeline_service_create(
538540
training_pipeline=true_training_pipeline,
539541
)
540542

541-
mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME)
543+
mock_model_service_get.assert_called_once_with(
544+
name=_TEST_MODEL_NAME, retry=base._DEFAULT_RETRY
545+
)
542546
assert job._gca_resource is mock_pipeline_service_get.return_value
543547
assert model_from_job._gca_resource is mock_model_service_get.return_value
544548
assert job.get_model()._gca_resource is mock_model_service_get.return_value

tests/unit/aiplatform/test_custom_job.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from test_training_jobs import mock_python_package_to_gcs # noqa: F401
3030

3131
from google.cloud import aiplatform
32+
from google.cloud.aiplatform import base
3233
from google.cloud.aiplatform.compat.types import custom_job as gca_custom_job_compat
3334
from google.cloud.aiplatform.compat.types import (
3435
custom_job_v1beta1 as gca_custom_job_v1beta1,
@@ -447,7 +448,9 @@ def test_get_custom_job(self, get_custom_job_mock):
447448

448449
job = aiplatform.CustomJob.get(_TEST_CUSTOM_JOB_NAME)
449450

450-
get_custom_job_mock.assert_called_once_with(name=_TEST_CUSTOM_JOB_NAME)
451+
get_custom_job_mock.assert_called_once_with(
452+
name=_TEST_CUSTOM_JOB_NAME, retry=base._DEFAULT_RETRY
453+
)
451454
assert (
452455
job._gca_resource.state == gca_job_state_compat.JobState.JOB_STATE_PENDING
453456
)

tests/unit/aiplatform/test_datasets.py

+48-18
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,13 @@
2828
from google.auth import credentials as auth_credentials
2929

3030
from google.cloud import aiplatform
31-
from google.cloud import bigquery
32-
from google.cloud import storage
33-
31+
from google.cloud.aiplatform import base
3432
from google.cloud.aiplatform import compat
3533
from google.cloud.aiplatform import datasets
3634
from google.cloud.aiplatform import initializer
3735
from google.cloud.aiplatform import schema
36+
from google.cloud import bigquery
37+
from google.cloud import storage
3838

3939
from google.cloud.aiplatform_v1.services.dataset_service import (
4040
client as dataset_service_client,
@@ -474,7 +474,9 @@ def teardown_method(self):
474474
def test_init_dataset(self, get_dataset_mock):
475475
aiplatform.init(project=_TEST_PROJECT)
476476
datasets._Dataset(dataset_name=_TEST_NAME)
477-
get_dataset_mock.assert_called_once_with(name=_TEST_NAME)
477+
get_dataset_mock.assert_called_once_with(
478+
name=_TEST_NAME, retry=base._DEFAULT_RETRY
479+
)
478480

479481
def test_init_dataset_with_id_only_with_project_and_location(
480482
self, get_dataset_mock
@@ -483,21 +485,27 @@ def test_init_dataset_with_id_only_with_project_and_location(
483485
datasets._Dataset(
484486
dataset_name=_TEST_ID, project=_TEST_PROJECT, location=_TEST_LOCATION
485487
)
486-
get_dataset_mock.assert_called_once_with(name=_TEST_NAME)
488+
get_dataset_mock.assert_called_once_with(
489+
name=_TEST_NAME, retry=base._DEFAULT_RETRY
490+
)
487491

488492
def test_init_dataset_with_project_and_location(self, get_dataset_mock):
489493
aiplatform.init(project=_TEST_PROJECT)
490494
datasets._Dataset(
491495
dataset_name=_TEST_NAME, project=_TEST_PROJECT, location=_TEST_LOCATION
492496
)
493-
get_dataset_mock.assert_called_once_with(name=_TEST_NAME)
497+
get_dataset_mock.assert_called_once_with(
498+
name=_TEST_NAME, retry=base._DEFAULT_RETRY
499+
)
494500

495501
def test_init_dataset_with_alt_project_and_location(self, get_dataset_mock):
496502
aiplatform.init(project=_TEST_PROJECT)
497503
datasets._Dataset(
498504
dataset_name=_TEST_NAME, project=_TEST_ALT_PROJECT, location=_TEST_LOCATION
499505
)
500-
get_dataset_mock.assert_called_once_with(name=_TEST_NAME)
506+
get_dataset_mock.assert_called_once_with(
507+
name=_TEST_NAME, retry=base._DEFAULT_RETRY
508+
)
501509

502510
def test_init_dataset_with_alt_location(self, get_dataset_tabular_gcs_mock):
503511
aiplatform.init(project=_TEST_PROJECT, location=_TEST_ALT_LOCATION)
@@ -511,7 +519,9 @@ def test_init_dataset_with_alt_location(self, get_dataset_tabular_gcs_mock):
511519

512520
assert _TEST_ALT_LOCATION != _TEST_LOCATION
513521

514-
get_dataset_tabular_gcs_mock.assert_called_once_with(name=_TEST_NAME)
522+
get_dataset_tabular_gcs_mock.assert_called_once_with(
523+
name=_TEST_NAME, retry=base._DEFAULT_RETRY
524+
)
515525

516526
def test_init_dataset_with_project_and_alt_location(self):
517527
aiplatform.init(project=_TEST_PROJECT)
@@ -525,7 +535,9 @@ def test_init_dataset_with_project_and_alt_location(self):
525535
def test_init_dataset_with_id_only(self, get_dataset_mock):
526536
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
527537
datasets._Dataset(dataset_name=_TEST_ID)
528-
get_dataset_mock.assert_called_once_with(name=_TEST_NAME)
538+
get_dataset_mock.assert_called_once_with(
539+
name=_TEST_NAME, retry=base._DEFAULT_RETRY
540+
)
529541

530542
@pytest.mark.usefixtures("get_dataset_without_name_mock")
531543
@patch.dict(
@@ -541,7 +553,9 @@ def test_init_dataset_with_id_only_without_project_or_location(self):
541553
def test_init_dataset_with_location_override(self, get_dataset_mock):
542554
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
543555
datasets._Dataset(dataset_name=_TEST_ID, location=_TEST_ALT_LOCATION)
544-
get_dataset_mock.assert_called_once_with(name=_TEST_ALT_NAME)
556+
get_dataset_mock.assert_called_once_with(
557+
name=_TEST_ALT_NAME, retry=base._DEFAULT_RETRY
558+
)
545559

546560
@pytest.mark.usefixtures("get_dataset_mock")
547561
def test_init_dataset_with_invalid_name(self):
@@ -764,7 +778,9 @@ def test_create_then_import(
764778
metadata=_TEST_REQUEST_METADATA,
765779
)
766780

767-
get_dataset_mock.assert_called_once_with(name=_TEST_NAME)
781+
get_dataset_mock.assert_called_once_with(
782+
name=_TEST_NAME, retry=base._DEFAULT_RETRY
783+
)
768784

769785
import_data_mock.assert_called_once_with(
770786
name=_TEST_NAME, import_configs=[expected_import_config]
@@ -798,7 +814,9 @@ def teardown_method(self):
798814
def test_init_dataset_image(self, get_dataset_image_mock):
799815
aiplatform.init(project=_TEST_PROJECT)
800816
datasets.ImageDataset(dataset_name=_TEST_NAME)
801-
get_dataset_image_mock.assert_called_once_with(name=_TEST_NAME)
817+
get_dataset_image_mock.assert_called_once_with(
818+
name=_TEST_NAME, retry=base._DEFAULT_RETRY
819+
)
802820

803821
@pytest.mark.usefixtures("get_dataset_tabular_bq_mock")
804822
def test_init_dataset_non_image(self):
@@ -934,7 +952,9 @@ def test_create_then_import(
934952
metadata=_TEST_REQUEST_METADATA,
935953
)
936954

937-
get_dataset_image_mock.assert_called_once_with(name=_TEST_NAME)
955+
get_dataset_image_mock.assert_called_once_with(
956+
name=_TEST_NAME, retry=base._DEFAULT_RETRY
957+
)
938958

939959
expected_import_config = gca_dataset.ImportDataConfig(
940960
gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]),
@@ -989,7 +1009,9 @@ def teardown_method(self):
9891009
def test_init_dataset_tabular(self, get_dataset_tabular_bq_mock):
9901010

9911011
datasets.TabularDataset(dataset_name=_TEST_NAME)
992-
get_dataset_tabular_bq_mock.assert_called_once_with(name=_TEST_NAME)
1012+
get_dataset_tabular_bq_mock.assert_called_once_with(
1013+
name=_TEST_NAME, retry=base._DEFAULT_RETRY
1014+
)
9931015

9941016
@pytest.mark.usefixtures("get_dataset_image_mock")
9951017
def test_init_dataset_non_tabular(self):
@@ -1236,7 +1258,9 @@ def teardown_method(self):
12361258
def test_init_dataset_text(self, get_dataset_text_mock):
12371259
aiplatform.init(project=_TEST_PROJECT)
12381260
datasets.TextDataset(dataset_name=_TEST_NAME)
1239-
get_dataset_text_mock.assert_called_once_with(name=_TEST_NAME)
1261+
get_dataset_text_mock.assert_called_once_with(
1262+
name=_TEST_NAME, retry=base._DEFAULT_RETRY
1263+
)
12401264

12411265
@pytest.mark.usefixtures("get_dataset_image_mock")
12421266
def test_init_dataset_non_text(self):
@@ -1409,7 +1433,9 @@ def test_create_then_import(
14091433
metadata=_TEST_REQUEST_METADATA,
14101434
)
14111435

1412-
get_dataset_text_mock.assert_called_once_with(name=_TEST_NAME)
1436+
get_dataset_text_mock.assert_called_once_with(
1437+
name=_TEST_NAME, retry=base._DEFAULT_RETRY
1438+
)
14131439

14141440
expected_import_config = gca_dataset.ImportDataConfig(
14151441
gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]),
@@ -1463,7 +1489,9 @@ def teardown_method(self):
14631489
def test_init_dataset_video(self, get_dataset_video_mock):
14641490
aiplatform.init(project=_TEST_PROJECT)
14651491
datasets.VideoDataset(dataset_name=_TEST_NAME)
1466-
get_dataset_video_mock.assert_called_once_with(name=_TEST_NAME)
1492+
get_dataset_video_mock.assert_called_once_with(
1493+
name=_TEST_NAME, retry=base._DEFAULT_RETRY
1494+
)
14671495

14681496
@pytest.mark.usefixtures("get_dataset_tabular_bq_mock")
14691497
def test_init_dataset_non_video(self):
@@ -1599,7 +1627,9 @@ def test_create_then_import(
15991627
metadata=_TEST_REQUEST_METADATA,
16001628
)
16011629

1602-
get_dataset_video_mock.assert_called_once_with(name=_TEST_NAME)
1630+
get_dataset_video_mock.assert_called_once_with(
1631+
name=_TEST_NAME, retry=base._DEFAULT_RETRY
1632+
)
16031633

16041634
expected_import_config = gca_dataset.ImportDataConfig(
16051635
gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]),

0 commit comments

Comments
 (0)