Skip to content

Commit b4b1b12

Browse files
authored
feat: Add service account support to Custom Training and Model deployment (#342)
1 parent a5fa7a2 commit b4b1b12

File tree

5 files changed

+103
-15
lines changed

5 files changed

+103
-15
lines changed

google/cloud/aiplatform/models.py

Lines changed: 53 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,7 @@ def deploy(
477477
max_replica_count: int = 1,
478478
accelerator_type: Optional[str] = None,
479479
accelerator_count: Optional[int] = None,
480+
service_account: Optional[str] = None,
480481
explanation_metadata: Optional[explain.ExplanationMetadata] = None,
481482
explanation_parameters: Optional[explain.ExplanationParameters] = None,
482483
metadata: Optional[Sequence[Tuple[str, str]]] = (),
@@ -531,6 +532,13 @@ def deploy(
531532
NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, NVIDIA_TESLA_T4, TPU_V2, TPU_V3
532533
accelerator_count (int):
533534
Optional. The number of accelerators to attach to a worker replica.
535+
service_account (str):
536+
The service account that the DeployedModel's container runs as. Specify the
537+
email address of the service account. If this service account is not
538+
specified, the container runs as a service account that doesn't have access
539+
to the resource project.
540+
Users deploying the Model must have the `iam.serviceAccounts.actAs`
541+
permission on this service account.
534542
explanation_metadata (explain.ExplanationMetadata):
535543
Optional. Metadata describing the Model's input and output for explanation.
536544
Both `explanation_metadata` and `explanation_parameters` must be
@@ -569,6 +577,7 @@ def deploy(
569577
max_replica_count=max_replica_count,
570578
accelerator_type=accelerator_type,
571579
accelerator_count=accelerator_count,
580+
service_account=service_account,
572581
explanation_metadata=explanation_metadata,
573582
explanation_parameters=explanation_parameters,
574583
metadata=metadata,
@@ -587,6 +596,7 @@ def _deploy(
587596
max_replica_count: Optional[int] = 1,
588597
accelerator_type: Optional[str] = None,
589598
accelerator_count: Optional[int] = None,
599+
service_account: Optional[str] = None,
590600
explanation_metadata: Optional[explain.ExplanationMetadata] = None,
591601
explanation_parameters: Optional[explain.ExplanationParameters] = None,
592602
metadata: Optional[Sequence[Tuple[str, str]]] = (),
@@ -641,6 +651,13 @@ def _deploy(
641651
NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, NVIDIA_TESLA_T4, TPU_V2, TPU_V3
642652
accelerator_count (int):
643653
Optional. The number of accelerators to attach to a worker replica.
654+
service_account (str):
655+
The service account that the DeployedModel's container runs as. Specify the
656+
email address of the service account. If this service account is not
657+
specified, the container runs as a service account that doesn't have access
658+
to the resource project.
659+
Users deploying the Model must have the `iam.serviceAccounts.actAs`
660+
permission on this service account.
644661
explanation_metadata (explain.ExplanationMetadata):
645662
Optional. Metadata describing the Model's input and output for explanation.
646663
Both `explanation_metadata` and `explanation_parameters` must be
@@ -677,6 +694,7 @@ def _deploy(
677694
max_replica_count=max_replica_count,
678695
accelerator_type=accelerator_type,
679696
accelerator_count=accelerator_count,
697+
service_account=service_account,
680698
explanation_metadata=explanation_metadata,
681699
explanation_parameters=explanation_parameters,
682700
metadata=metadata,
@@ -701,6 +719,7 @@ def _deploy_call(
701719
max_replica_count: Optional[int] = 1,
702720
accelerator_type: Optional[str] = None,
703721
accelerator_count: Optional[int] = None,
722+
service_account: Optional[str] = None,
704723
explanation_metadata: Optional[explain.ExplanationMetadata] = None,
705724
explanation_parameters: Optional[explain.ExplanationParameters] = None,
706725
metadata: Optional[Sequence[Tuple[str, str]]] = (),
@@ -753,6 +772,13 @@ def _deploy_call(
753772
is not provided, the larger value of min_replica_count or 1 will
754773
be used. If value provided is smaller than min_replica_count, it
755774
will automatically be increased to be min_replica_count.
775+
service_account (str):
776+
The service account that the DeployedModel's container runs as. Specify the
777+
email address of the service account. If this service account is not
778+
specified, the container runs as a service account that doesn't have access
779+
to the resource project.
780+
Users deploying the Model must have the `iam.serviceAccounts.actAs`
781+
permission on this service account.
756782
explanation_metadata (explain.ExplanationMetadata):
757783
Optional. Metadata describing the Model's input and output for explanation.
758784
Both `explanation_metadata` and `explanation_parameters` must be
@@ -788,6 +814,12 @@ def _deploy_call(
788814
gca_endpoint = gca_endpoint_v1beta1
789815
gca_machine_resources = gca_machine_resources_v1beta1
790816

817+
deployed_model = gca_endpoint.DeployedModel(
818+
model=model_resource_name,
819+
display_name=deployed_model_display_name,
820+
service_account=service_account,
821+
)
822+
791823
if machine_type:
792824
machine_spec = gca_machine_resources.MachineSpec(machine_type=machine_type)
793825

@@ -796,26 +828,17 @@ def _deploy_call(
796828
machine_spec.accelerator_type = accelerator_type
797829
machine_spec.accelerator_count = accelerator_count
798830

799-
dedicated_resources = gca_machine_resources.DedicatedResources(
831+
deployed_model.dedicated_resources = gca_machine_resources.DedicatedResources(
800832
machine_spec=machine_spec,
801833
min_replica_count=min_replica_count,
802834
max_replica_count=max_replica_count,
803835
)
804-
deployed_model = gca_endpoint.DeployedModel(
805-
dedicated_resources=dedicated_resources,
806-
model=model_resource_name,
807-
display_name=deployed_model_display_name,
808-
)
836+
809837
else:
810-
automatic_resources = gca_machine_resources.AutomaticResources(
838+
deployed_model.automatic_resources = gca_machine_resources.AutomaticResources(
811839
min_replica_count=min_replica_count,
812840
max_replica_count=max_replica_count,
813841
)
814-
deployed_model = gca_endpoint.DeployedModel(
815-
automatic_resources=automatic_resources,
816-
model=model_resource_name,
817-
display_name=deployed_model_display_name,
818-
)
819842

820843
# Service will throw error if both metadata and parameters are not provided
821844
if explanation_metadata and explanation_parameters:
@@ -1493,6 +1516,7 @@ def deploy(
14931516
max_replica_count: Optional[int] = 1,
14941517
accelerator_type: Optional[str] = None,
14951518
accelerator_count: Optional[int] = None,
1519+
service_account: Optional[str] = None,
14961520
explanation_metadata: Optional[explain.ExplanationMetadata] = None,
14971521
explanation_parameters: Optional[explain.ExplanationParameters] = None,
14981522
metadata: Optional[Sequence[Tuple[str, str]]] = (),
@@ -1548,6 +1572,13 @@ def deploy(
15481572
NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, NVIDIA_TESLA_T4, TPU_V2, TPU_V3
15491573
accelerator_count (int):
15501574
Optional. The number of accelerators to attach to a worker replica.
1575+
service_account (str):
1576+
The service account that the DeployedModel's container runs as. Specify the
1577+
email address of the service account. If this service account is not
1578+
specified, the container runs as a service account that doesn't have access
1579+
to the resource project.
1580+
Users deploying the Model must have the `iam.serviceAccounts.actAs`
1581+
permission on this service account.
15511582
explanation_metadata (explain.ExplanationMetadata):
15521583
Optional. Metadata describing the Model's input and output for explanation.
15531584
Both `explanation_metadata` and `explanation_parameters` must be
@@ -1601,6 +1632,7 @@ def deploy(
16011632
max_replica_count=max_replica_count,
16021633
accelerator_type=accelerator_type,
16031634
accelerator_count=accelerator_count,
1635+
service_account=service_account,
16041636
explanation_metadata=explanation_metadata,
16051637
explanation_parameters=explanation_parameters,
16061638
metadata=metadata,
@@ -1621,6 +1653,7 @@ def _deploy(
16211653
max_replica_count: Optional[int] = 1,
16221654
accelerator_type: Optional[str] = None,
16231655
accelerator_count: Optional[int] = None,
1656+
service_account: Optional[str] = None,
16241657
explanation_metadata: Optional[explain.ExplanationMetadata] = None,
16251658
explanation_parameters: Optional[explain.ExplanationParameters] = None,
16261659
metadata: Optional[Sequence[Tuple[str, str]]] = (),
@@ -1676,6 +1709,13 @@ def _deploy(
16761709
NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, NVIDIA_TESLA_T4, TPU_V2, TPU_V3
16771710
accelerator_count (int):
16781711
Optional. The number of accelerators to attach to a worker replica.
1712+
service_account (str):
1713+
The service account that the DeployedModel's container runs as. Specify the
1714+
email address of the service account. If this service account is not
1715+
specified, the container runs as a service account that doesn't have access
1716+
to the resource project.
1717+
Users deploying the Model must have the `iam.serviceAccounts.actAs`
1718+
permission on this service account.
16791719
explanation_metadata (explain.ExplanationMetadata):
16801720
Optional. Metadata describing the Model's input and output for explanation.
16811721
Both `explanation_metadata` and `explanation_parameters` must be
@@ -1732,6 +1772,7 @@ def _deploy(
17321772
max_replica_count=max_replica_count,
17331773
accelerator_type=accelerator_type,
17341774
accelerator_count=accelerator_count,
1775+
service_account=service_account,
17351776
explanation_metadata=explanation_metadata,
17361777
explanation_parameters=explanation_parameters,
17371778
metadata=metadata,

google/cloud/aiplatform/training_jobs.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1517,6 +1517,7 @@ def _prepare_training_task_inputs_and_output_dir(
15171517
self,
15181518
worker_pool_specs: _DistributedTrainingSpec,
15191519
base_output_dir: Optional[str] = None,
1520+
service_account: Optional[str] = None,
15201521
) -> Tuple[Dict, str]:
15211522
"""Prepares training task inputs and output directory for custom job.
15221523
@@ -1526,6 +1527,9 @@ def _prepare_training_task_inputs_and_output_dir(
15261527
base_output_dir (str):
15271528
GCS output directory of job. If not provided a
15281529
timestamped directory in the staging directory will be used.
1530+
service_account (str):
1531+
Specifies the service account for workload run-as account.
1532+
Users submitting jobs must have act-as permission on this run-as account.
15291533
Returns:
15301534
Training task inputs and Output directory for custom job.
15311535
"""
@@ -1542,6 +1546,9 @@ def _prepare_training_task_inputs_and_output_dir(
15421546
"baseOutputDirectory": {"output_uri_prefix": base_output_dir},
15431547
}
15441548

1549+
if service_account:
1550+
training_task_inputs["serviceAccount"] = service_account
1551+
15451552
return training_task_inputs, base_output_dir
15461553

15471554
@property
@@ -1787,6 +1794,7 @@ def run(
17871794
annotation_schema_uri: Optional[str] = None,
17881795
model_display_name: Optional[str] = None,
17891796
base_output_dir: Optional[str] = None,
1797+
service_account: Optional[str] = None,
17901798
bigquery_destination: Optional[str] = None,
17911799
args: Optional[List[Union[str, float, int]]] = None,
17921800
replica_count: int = 0,
@@ -1864,6 +1872,9 @@ def run(
18641872
base_output_dir (str):
18651873
GCS output directory of job. If not provided a
18661874
timestamped directory in the staging directory will be used.
1875+
service_account (str):
1876+
Specifies the service account for workload run-as account.
1877+
Users submitting jobs must have act-as permission on this run-as account.
18671878
bigquery_destination (str):
18681879
Provide this field if `dataset` is a BiqQuery dataset.
18691880
The BigQuery project location where the training data is to
@@ -1942,6 +1953,7 @@ def run(
19421953
managed_model=managed_model,
19431954
args=args,
19441955
base_output_dir=base_output_dir,
1956+
service_account=service_account,
19451957
bigquery_destination=bigquery_destination,
19461958
training_fraction_split=training_fraction_split,
19471959
validation_fraction_split=validation_fraction_split,
@@ -1967,6 +1979,7 @@ def _run(
19671979
managed_model: Optional[gca_model.Model] = None,
19681980
args: Optional[List[Union[str, float, int]]] = None,
19691981
base_output_dir: Optional[str] = None,
1982+
service_account: Optional[str] = None,
19701983
bigquery_destination: Optional[str] = None,
19711984
training_fraction_split: float = 0.8,
19721985
validation_fraction_split: float = 0.1,
@@ -2000,6 +2013,9 @@ def _run(
20002013
base_output_dir (str):
20012014
GCS output directory of job. If not provided a
20022015
timestamped directory in the staging directory will be used.
2016+
service_account (str):
2017+
Specifies the service account for workload run-as account.
2018+
Users submitting jobs must have act-as permission on this run-as account.
20032019
bigquery_destination (str):
20042020
Provide this field if `dataset` is a BiqQuery dataset.
20052021
The BigQuery project location where the training data is to
@@ -2063,7 +2079,7 @@ def _run(
20632079
training_task_inputs,
20642080
base_output_dir,
20652081
) = self._prepare_training_task_inputs_and_output_dir(
2066-
worker_pool_specs, base_output_dir
2082+
worker_pool_specs, base_output_dir, service_account
20672083
)
20682084

20692085
model = self._run_job(
@@ -2306,6 +2322,7 @@ def run(
23062322
annotation_schema_uri: Optional[str] = None,
23072323
model_display_name: Optional[str] = None,
23082324
base_output_dir: Optional[str] = None,
2325+
service_account: Optional[str] = None,
23092326
bigquery_destination: Optional[str] = None,
23102327
args: Optional[List[Union[str, float, int]]] = None,
23112328
replica_count: int = 0,
@@ -2383,6 +2400,9 @@ def run(
23832400
base_output_dir (str):
23842401
GCS output directory of job. If not provided a
23852402
timestamped directory in the staging directory will be used.
2403+
service_account (str):
2404+
Specifies the service account for workload run-as account.
2405+
Users submitting jobs must have act-as permission on this run-as account.
23862406
bigquery_destination (str):
23872407
Provide this field if `dataset` is a BiqQuery dataset.
23882408
The BigQuery project location where the training data is to
@@ -2460,6 +2480,7 @@ def run(
24602480
managed_model=managed_model,
24612481
args=args,
24622482
base_output_dir=base_output_dir,
2483+
service_account=service_account,
24632484
bigquery_destination=bigquery_destination,
24642485
training_fraction_split=training_fraction_split,
24652486
validation_fraction_split=validation_fraction_split,
@@ -2484,6 +2505,7 @@ def _run(
24842505
managed_model: Optional[gca_model.Model] = None,
24852506
args: Optional[List[Union[str, float, int]]] = None,
24862507
base_output_dir: Optional[str] = None,
2508+
service_account: Optional[str] = None,
24872509
bigquery_destination: Optional[str] = None,
24882510
training_fraction_split: float = 0.8,
24892511
validation_fraction_split: float = 0.1,
@@ -2514,6 +2536,9 @@ def _run(
25142536
base_output_dir (str):
25152537
GCS output directory of job. If not provided a
25162538
timestamped directory in the staging directory will be used.
2539+
service_account (str):
2540+
Specifies the service account for workload run-as account.
2541+
Users submitting jobs must have act-as permission on this run-as account.
25172542
bigquery_destination (str):
25182543
The BigQuery project location where the training data is to
25192544
be written to. In the given project a new dataset is created
@@ -2570,7 +2595,7 @@ def _run(
25702595
training_task_inputs,
25712596
base_output_dir,
25722597
) = self._prepare_training_task_inputs_and_output_dir(
2573-
worker_pool_specs, base_output_dir
2598+
worker_pool_specs, base_output_dir, service_account
25742599
)
25752600

25762601
model = self._run_job(
@@ -3573,6 +3598,7 @@ def run(
35733598
annotation_schema_uri: Optional[str] = None,
35743599
model_display_name: Optional[str] = None,
35753600
base_output_dir: Optional[str] = None,
3601+
service_account: Optional[str] = None,
35763602
bigquery_destination: Optional[str] = None,
35773603
args: Optional[List[Union[str, float, int]]] = None,
35783604
replica_count: int = 0,
@@ -3650,6 +3676,9 @@ def run(
36503676
base_output_dir (str):
36513677
GCS output directory of job. If not provided a
36523678
timestamped directory in the staging directory will be used.
3679+
service_account (str):
3680+
Specifies the service account for workload run-as account.
3681+
Users submitting jobs must have act-as permission on this run-as account.
36533682
bigquery_destination (str):
36543683
Provide this field if `dataset` is a BiqQuery dataset.
36553684
The BigQuery project location where the training data is to
@@ -3722,6 +3751,7 @@ def run(
37223751
managed_model=managed_model,
37233752
args=args,
37243753
base_output_dir=base_output_dir,
3754+
service_account=service_account,
37253755
training_fraction_split=training_fraction_split,
37263756
validation_fraction_split=validation_fraction_split,
37273757
test_fraction_split=test_fraction_split,
@@ -3746,6 +3776,7 @@ def _run(
37463776
managed_model: Optional[gca_model.Model] = None,
37473777
args: Optional[List[Union[str, float, int]]] = None,
37483778
base_output_dir: Optional[str] = None,
3779+
service_account: Optional[str] = None,
37493780
training_fraction_split: float = 0.8,
37503781
validation_fraction_split: float = 0.1,
37513782
test_fraction_split: float = 0.1,
@@ -3777,6 +3808,9 @@ def _run(
37773808
base_output_dir (str):
37783809
GCS output directory of job. If not provided a
37793810
timestamped directory in the staging directory will be used.
3811+
service_account (str):
3812+
Specifies the service account for workload run-as account.
3813+
Users submitting jobs must have act-as permission on this run-as account.
37803814
training_fraction_split (float):
37813815
The fraction of the input data that is to be
37823816
used to train the Model.
@@ -3819,7 +3853,7 @@ def _run(
38193853
training_task_inputs,
38203854
base_output_dir,
38213855
) = self._prepare_training_task_inputs_and_output_dir(
3822-
worker_pool_specs, base_output_dir
3856+
worker_pool_specs, base_output_dir, service_account
38233857
)
38243858

38253859
model = self._run_job(

0 commit comments

Comments
 (0)