28
28
from google .auth import credentials as auth_credentials
29
29
30
30
from google .cloud import aiplatform
31
- from google .cloud import bigquery
32
- from google .cloud import storage
33
-
31
+ from google .cloud .aiplatform import base
34
32
from google .cloud .aiplatform import compat
35
33
from google .cloud .aiplatform import datasets
36
34
from google .cloud .aiplatform import initializer
37
35
from google .cloud .aiplatform import schema
36
+ from google .cloud import bigquery
37
+ from google .cloud import storage
38
38
39
39
from google .cloud .aiplatform_v1 .services .dataset_service import (
40
40
client as dataset_service_client ,
@@ -474,7 +474,9 @@ def teardown_method(self):
474
474
def test_init_dataset (self , get_dataset_mock ):
475
475
aiplatform .init (project = _TEST_PROJECT )
476
476
datasets ._Dataset (dataset_name = _TEST_NAME )
477
- get_dataset_mock .assert_called_once_with (name = _TEST_NAME )
477
+ get_dataset_mock .assert_called_once_with (
478
+ name = _TEST_NAME , retry = base ._DEFAULT_RETRY
479
+ )
478
480
479
481
def test_init_dataset_with_id_only_with_project_and_location (
480
482
self , get_dataset_mock
@@ -483,21 +485,27 @@ def test_init_dataset_with_id_only_with_project_and_location(
483
485
datasets ._Dataset (
484
486
dataset_name = _TEST_ID , project = _TEST_PROJECT , location = _TEST_LOCATION
485
487
)
486
- get_dataset_mock .assert_called_once_with (name = _TEST_NAME )
488
+ get_dataset_mock .assert_called_once_with (
489
+ name = _TEST_NAME , retry = base ._DEFAULT_RETRY
490
+ )
487
491
488
492
def test_init_dataset_with_project_and_location (self , get_dataset_mock ):
489
493
aiplatform .init (project = _TEST_PROJECT )
490
494
datasets ._Dataset (
491
495
dataset_name = _TEST_NAME , project = _TEST_PROJECT , location = _TEST_LOCATION
492
496
)
493
- get_dataset_mock .assert_called_once_with (name = _TEST_NAME )
497
+ get_dataset_mock .assert_called_once_with (
498
+ name = _TEST_NAME , retry = base ._DEFAULT_RETRY
499
+ )
494
500
495
501
def test_init_dataset_with_alt_project_and_location (self , get_dataset_mock ):
496
502
aiplatform .init (project = _TEST_PROJECT )
497
503
datasets ._Dataset (
498
504
dataset_name = _TEST_NAME , project = _TEST_ALT_PROJECT , location = _TEST_LOCATION
499
505
)
500
- get_dataset_mock .assert_called_once_with (name = _TEST_NAME )
506
+ get_dataset_mock .assert_called_once_with (
507
+ name = _TEST_NAME , retry = base ._DEFAULT_RETRY
508
+ )
501
509
502
510
def test_init_dataset_with_alt_location (self , get_dataset_tabular_gcs_mock ):
503
511
aiplatform .init (project = _TEST_PROJECT , location = _TEST_ALT_LOCATION )
@@ -511,7 +519,9 @@ def test_init_dataset_with_alt_location(self, get_dataset_tabular_gcs_mock):
511
519
512
520
assert _TEST_ALT_LOCATION != _TEST_LOCATION
513
521
514
- get_dataset_tabular_gcs_mock .assert_called_once_with (name = _TEST_NAME )
522
+ get_dataset_tabular_gcs_mock .assert_called_once_with (
523
+ name = _TEST_NAME , retry = base ._DEFAULT_RETRY
524
+ )
515
525
516
526
def test_init_dataset_with_project_and_alt_location (self ):
517
527
aiplatform .init (project = _TEST_PROJECT )
@@ -525,7 +535,9 @@ def test_init_dataset_with_project_and_alt_location(self):
525
535
def test_init_dataset_with_id_only (self , get_dataset_mock ):
526
536
aiplatform .init (project = _TEST_PROJECT , location = _TEST_LOCATION )
527
537
datasets ._Dataset (dataset_name = _TEST_ID )
528
- get_dataset_mock .assert_called_once_with (name = _TEST_NAME )
538
+ get_dataset_mock .assert_called_once_with (
539
+ name = _TEST_NAME , retry = base ._DEFAULT_RETRY
540
+ )
529
541
530
542
@pytest .mark .usefixtures ("get_dataset_without_name_mock" )
531
543
@patch .dict (
@@ -541,7 +553,9 @@ def test_init_dataset_with_id_only_without_project_or_location(self):
541
553
def test_init_dataset_with_location_override (self , get_dataset_mock ):
542
554
aiplatform .init (project = _TEST_PROJECT , location = _TEST_LOCATION )
543
555
datasets ._Dataset (dataset_name = _TEST_ID , location = _TEST_ALT_LOCATION )
544
- get_dataset_mock .assert_called_once_with (name = _TEST_ALT_NAME )
556
+ get_dataset_mock .assert_called_once_with (
557
+ name = _TEST_ALT_NAME , retry = base ._DEFAULT_RETRY
558
+ )
545
559
546
560
@pytest .mark .usefixtures ("get_dataset_mock" )
547
561
def test_init_dataset_with_invalid_name (self ):
@@ -764,7 +778,9 @@ def test_create_then_import(
764
778
metadata = _TEST_REQUEST_METADATA ,
765
779
)
766
780
767
- get_dataset_mock .assert_called_once_with (name = _TEST_NAME )
781
+ get_dataset_mock .assert_called_once_with (
782
+ name = _TEST_NAME , retry = base ._DEFAULT_RETRY
783
+ )
768
784
769
785
import_data_mock .assert_called_once_with (
770
786
name = _TEST_NAME , import_configs = [expected_import_config ]
@@ -798,7 +814,9 @@ def teardown_method(self):
798
814
def test_init_dataset_image (self , get_dataset_image_mock ):
799
815
aiplatform .init (project = _TEST_PROJECT )
800
816
datasets .ImageDataset (dataset_name = _TEST_NAME )
801
- get_dataset_image_mock .assert_called_once_with (name = _TEST_NAME )
817
+ get_dataset_image_mock .assert_called_once_with (
818
+ name = _TEST_NAME , retry = base ._DEFAULT_RETRY
819
+ )
802
820
803
821
@pytest .mark .usefixtures ("get_dataset_tabular_bq_mock" )
804
822
def test_init_dataset_non_image (self ):
@@ -934,7 +952,9 @@ def test_create_then_import(
934
952
metadata = _TEST_REQUEST_METADATA ,
935
953
)
936
954
937
- get_dataset_image_mock .assert_called_once_with (name = _TEST_NAME )
955
+ get_dataset_image_mock .assert_called_once_with (
956
+ name = _TEST_NAME , retry = base ._DEFAULT_RETRY
957
+ )
938
958
939
959
expected_import_config = gca_dataset .ImportDataConfig (
940
960
gcs_source = gca_io .GcsSource (uris = [_TEST_SOURCE_URI_GCS ]),
@@ -989,7 +1009,9 @@ def teardown_method(self):
989
1009
def test_init_dataset_tabular (self , get_dataset_tabular_bq_mock ):
990
1010
991
1011
datasets .TabularDataset (dataset_name = _TEST_NAME )
992
- get_dataset_tabular_bq_mock .assert_called_once_with (name = _TEST_NAME )
1012
+ get_dataset_tabular_bq_mock .assert_called_once_with (
1013
+ name = _TEST_NAME , retry = base ._DEFAULT_RETRY
1014
+ )
993
1015
994
1016
@pytest .mark .usefixtures ("get_dataset_image_mock" )
995
1017
def test_init_dataset_non_tabular (self ):
@@ -1236,7 +1258,9 @@ def teardown_method(self):
1236
1258
def test_init_dataset_text (self , get_dataset_text_mock ):
1237
1259
aiplatform .init (project = _TEST_PROJECT )
1238
1260
datasets .TextDataset (dataset_name = _TEST_NAME )
1239
- get_dataset_text_mock .assert_called_once_with (name = _TEST_NAME )
1261
+ get_dataset_text_mock .assert_called_once_with (
1262
+ name = _TEST_NAME , retry = base ._DEFAULT_RETRY
1263
+ )
1240
1264
1241
1265
@pytest .mark .usefixtures ("get_dataset_image_mock" )
1242
1266
def test_init_dataset_non_text (self ):
@@ -1409,7 +1433,9 @@ def test_create_then_import(
1409
1433
metadata = _TEST_REQUEST_METADATA ,
1410
1434
)
1411
1435
1412
- get_dataset_text_mock .assert_called_once_with (name = _TEST_NAME )
1436
+ get_dataset_text_mock .assert_called_once_with (
1437
+ name = _TEST_NAME , retry = base ._DEFAULT_RETRY
1438
+ )
1413
1439
1414
1440
expected_import_config = gca_dataset .ImportDataConfig (
1415
1441
gcs_source = gca_io .GcsSource (uris = [_TEST_SOURCE_URI_GCS ]),
@@ -1463,7 +1489,9 @@ def teardown_method(self):
1463
1489
def test_init_dataset_video (self , get_dataset_video_mock ):
1464
1490
aiplatform .init (project = _TEST_PROJECT )
1465
1491
datasets .VideoDataset (dataset_name = _TEST_NAME )
1466
- get_dataset_video_mock .assert_called_once_with (name = _TEST_NAME )
1492
+ get_dataset_video_mock .assert_called_once_with (
1493
+ name = _TEST_NAME , retry = base ._DEFAULT_RETRY
1494
+ )
1467
1495
1468
1496
@pytest .mark .usefixtures ("get_dataset_tabular_bq_mock" )
1469
1497
def test_init_dataset_non_video (self ):
@@ -1599,7 +1627,9 @@ def test_create_then_import(
1599
1627
metadata = _TEST_REQUEST_METADATA ,
1600
1628
)
1601
1629
1602
- get_dataset_video_mock .assert_called_once_with (name = _TEST_NAME )
1630
+ get_dataset_video_mock .assert_called_once_with (
1631
+ name = _TEST_NAME , retry = base ._DEFAULT_RETRY
1632
+ )
1603
1633
1604
1634
expected_import_config = gca_dataset .ImportDataConfig (
1605
1635
gcs_source = gca_io .GcsSource (uris = [_TEST_SOURCE_URI_GCS ]),
0 commit comments