Skip to content

Commit 572a27c

Browse files
authored
feat: Expose additional attributes into Vertex SDK to close gap with GAPIC (#477)
* Add most missing fields
* Add tests for get trainingjob subclass
* Drop Dataset len, add more attrs, update docstrings
* flake8 lint
* Address reviewer comments
* Switch 'an' to 'a' when referencing Vertex AI
* Address comments, move base attrs to subclasses
* Drop unused import
* Add test to ensure supported training schemas are always unique
* Address reviewer comments
1 parent e3cbdd8 commit 572a27c

File tree

7 files changed

+428
-40
lines changed

7 files changed

+428
-40
lines changed

google/cloud/aiplatform/base.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
from google.auth import credentials as auth_credentials
4343
from google.cloud.aiplatform import initializer
4444
from google.cloud.aiplatform import utils
45-
45+
from google.cloud.aiplatform.compat.types import encryption_spec as gca_encryption_spec
4646

4747
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
4848

@@ -563,6 +563,23 @@ def update_time(self) -> datetime.datetime:
563563
self._sync_gca_resource()
564564
return self._gca_resource.update_time
565565

566+
@property
567+
def encryption_spec(self) -> Optional[gca_encryption_spec.EncryptionSpec]:
568+
"""Customer-managed encryption key options for this Vertex AI resource.
569+
570+
If this is set, then all resources created by this Vertex AI resource will
571+
be encrypted with the provided encryption key.
572+
"""
573+
return getattr(self._gca_resource, "encryption_spec")
574+
575+
@property
576+
def labels(self) -> Dict[str, str]:
577+
"""User-defined labels containing metadata about this resource.
578+
579+
Read more about labels at https://goo.gl/xmQnxf
580+
"""
581+
return self._gca_resource.labels
582+
566583
@property
567584
def gca_resource(self) -> proto.Message:
568585
"""The underlying resource proto representation."""
@@ -813,7 +830,7 @@ def _construct_sdk_resource_from_gapic(
813830
814831
Args:
815832
gapic_resource (proto.Message):
816-
A GAPIC representation of an Vertex AI resource, usually
833+
A GAPIC representation of a Vertex AI resource, usually
817834
retrieved by a get_* or in a list_* API call.
818835
project (str):
819836
Optional. Project to construct SDK object from. If not set,

google/cloud/aiplatform/initializer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ def create_client(
267267
268268
Args:
269269
client_class (utils.VertexAiServiceClientWithOverride):
270-
(Required) An Vertex AI Service Client with optional overrides.
270+
(Required) A Vertex AI Service Client with optional overrides.
271271
credentials (auth_credentials.Credentials):
272272
Custom auth credentials. If not provided will use the current config.
273273
location_override (str): Optional location override.

google/cloud/aiplatform/jobs.py

Lines changed: 78 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import abc
2121
import copy
22+
import datetime
2223
import sys
2324
import time
2425
import logging
@@ -28,6 +29,7 @@
2829

2930
from google.auth import credentials as auth_credentials
3031
from google.protobuf import duration_pb2 # type: ignore
32+
from google.rpc import status_pb2
3133

3234
from google.cloud import aiplatform
3335
from google.cloud.aiplatform import base
@@ -45,6 +47,7 @@
4547
batch_prediction_job as gca_bp_job_compat,
4648
batch_prediction_job_v1 as gca_bp_job_v1,
4749
batch_prediction_job_v1beta1 as gca_bp_job_v1beta1,
50+
completion_stats as gca_completion_stats,
4851
custom_job as gca_custom_job_compat,
4952
custom_job_v1beta1 as gca_custom_job_v1beta1,
5053
explanation_v1beta1 as gca_explanation_v1beta1,
@@ -139,6 +142,27 @@ def state(self) -> gca_job_state.JobState:
139142

140143
return self._gca_resource.state
141144

145+
@property
146+
def start_time(self) -> Optional[datetime.datetime]:
147+
"""Time when the Job resource entered the `JOB_STATE_RUNNING` state for the
148+
first time."""
149+
self._sync_gca_resource()
150+
return getattr(self._gca_resource, "start_time")
151+
152+
@property
153+
def end_time(self) -> Optional[datetime.datetime]:
154+
"""Time when the Job resource entered the `JOB_STATE_SUCCEEDED`,
155+
`JOB_STATE_FAILED`, or `JOB_STATE_CANCELLED` state."""
156+
self._sync_gca_resource()
157+
return getattr(self._gca_resource, "end_time")
158+
159+
@property
160+
def error(self) -> Optional[status_pb2.Status]:
161+
"""Detailed error info for this Job resource. Only populated when the
162+
Job's state is `JOB_STATE_FAILED` or `JOB_STATE_CANCELLED`."""
163+
self._sync_gca_resource()
164+
return getattr(self._gca_resource, "error")
165+
142166
@property
143167
@abc.abstractmethod
144168
def _job_type(cls) -> str:
@@ -302,6 +326,27 @@ def __init__(
302326
credentials=credentials,
303327
)
304328

329+
@property
330+
def output_info(self,) -> Optional[aiplatform.gapic.BatchPredictionJob.OutputInfo]:
331+
"""Information describing the output of this job, including output location
332+
into which prediction output is written.
333+
334+
This is only available for batch prediction jobs that have run successfully.
335+
"""
336+
return self._gca_resource.output_info
337+
338+
@property
339+
def partial_failures(self) -> Optional[Sequence[status_pb2.Status]]:
340+
"""Partial failures encountered. For example, single files that can't be read.
341+
This field never exceeds 20 entries. Status details fields contain standard
342+
GCP error details."""
343+
return getattr(self._gca_resource, "partial_failures")
344+
345+
@property
346+
def completion_stats(self) -> Optional[gca_completion_stats.CompletionStats]:
347+
"""Statistics on completed and failed prediction instances."""
348+
return getattr(self._gca_resource, "completion_stats")
349+
305350
@classmethod
306351
def create(
307352
cls,
@@ -842,7 +887,7 @@ def get(
842887
location: Optional[str] = None,
843888
credentials: Optional[auth_credentials.Credentials] = None,
844889
) -> "_RunnableJob":
845-
"""Get an Vertex AI Job for the given resource_name.
890+
"""Get a Vertex AI Job for the given resource_name.
846891
847892
Args:
848893
resource_name (str):
@@ -858,7 +903,7 @@ def get(
858903
credentials set in aiplatform.init.
859904
860905
Returns:
861-
An Vertex AI Job.
906+
A Vertex AI Job.
862907
"""
863908
self = cls._empty_constructor(
864909
project=project,
@@ -887,7 +932,7 @@ class CustomJob(_RunnableJob):
887932

888933
_resource_noun = "customJobs"
889934
_getter_method = "get_custom_job"
890-
_list_method = "list_custom_job"
935+
_list_method = "list_custom_jobs"
891936
_cancel_method = "cancel_custom_job"
892937
_delete_method = "delete_custom_job"
893938
_job_type = "training"
@@ -987,6 +1032,20 @@ def __init__(
9871032
),
9881033
)
9891034

1035+
@property
1036+
def network(self) -> Optional[str]:
1037+
"""The full name of the Google Compute Engine
1038+
[network](https://cloud.google.com/vpc/docs/vpc#networks) to which this
1039+
CustomJob should be peered.
1040+
1041+
Takes the format `projects/{project}/global/networks/{network}`, where
1042+
{project} is a project number, as in `12345`, and {network} is a network name.
1043+
1044+
Private services access must already be configured for the network. If left
1045+
unspecified, the CustomJob is not peered with any network.
1046+
"""
1047+
return getattr(self._gca_resource, "network")
1048+
9901049
@classmethod
9911050
def from_local_script(
9921051
cls,
@@ -1157,7 +1216,7 @@ def run(
11571216
distributed training jobs that are not resilient
11581217
to workers leaving and joining a job.
11591218
tensorboard (str):
1160-
Optional. The name of an Vertex AI
1219+
Optional. The name of a Vertex AI
11611220
[Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard]
11621221
resource to which this CustomJob will upload Tensorboard
11631222
logs. Format:
@@ -1444,6 +1503,20 @@ def __init__(
14441503
),
14451504
)
14461505

1506+
@property
1507+
def network(self) -> Optional[str]:
1508+
"""The full name of the Google Compute Engine
1509+
[network](https://cloud.google.com/vpc/docs/vpc#networks) to which this
1510+
HyperparameterTuningJob should be peered.
1511+
1512+
Takes the format `projects/{project}/global/networks/{network}`, where
1513+
{project} is a project number, as in `12345`, and {network} is a network name.
1514+
1515+
Private services access must already be configured for the network. If left
1516+
unspecified, the HyperparameterTuningJob is not peered with any network.
1517+
"""
1518+
return getattr(self._gca_resource.trial_job_spec, "network")
1519+
14471520
@base.optional_sync()
14481521
def run(
14491522
self,
@@ -1473,7 +1546,7 @@ def run(
14731546
distributed training jobs that are not resilient
14741547
to workers leaving and joining a job.
14751548
tensorboard (str):
1476-
Optional. The name of an Vertex AI
1549+
Optional. The name of a Vertex AI
14771550
[Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard]
14781551
resource to which this CustomJob will upload Tensorboard
14791552
logs. Format:

google/cloud/aiplatform/models.py

Lines changed: 126 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818
from typing import Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
1919

2020
from google.api_core import operation
21+
from google.api_core import exceptions as api_exceptions
2122
from google.auth import credentials as auth_credentials
2223

24+
from google.cloud import aiplatform
2325
from google.cloud.aiplatform import base
2426
from google.cloud.aiplatform import compat
2527
from google.cloud.aiplatform import explain
@@ -119,6 +121,33 @@ def __init__(
119121
credentials=credentials,
120122
)
121123

124+
@property
125+
def traffic_split(self) -> Dict[str, int]:
126+
"""A map from a DeployedModel's ID to the percentage of this Endpoint's
127+
traffic that should be forwarded to that DeployedModel.
128+
129+
If a DeployedModel's ID is not listed in this map, then it receives no traffic.
130+
131+
The traffic percentage values must add up to 100, or map must be empty if
132+
the Endpoint is not to accept any traffic at the moment.
133+
"""
134+
self._sync_gca_resource()
135+
return dict(self._gca_resource.traffic_split)
136+
137+
@property
138+
def network(self) -> Optional[str]:
139+
"""The full name of the Google Compute Engine
140+
[network](https://cloud.google.com/vpc/docs/vpc#networks) to which this
141+
Endpoint should be peered.
142+
143+
Takes the format `projects/{project}/global/networks/{network}`, where
144+
{project} is a project number, as in `12345`, and {network} is a network name.
145+
146+
Private services access must already be configured for the network. If left
147+
unspecified, the Endpoint is not peered with any network.
148+
"""
149+
return getattr(self._gca_resource, "network")
150+
122151
@classmethod
123152
def create(
124153
cls,
@@ -1211,12 +1240,13 @@ class Model(base.VertexAiResourceNounWithFutureManager):
12111240
_delete_method = "delete_model"
12121241

12131242
@property
1214-
def uri(self):
1215-
"""Uri of the model."""
1216-
return self._gca_resource.artifact_uri
1243+
def uri(self) -> Optional[str]:
1244+
"""Path to the directory containing the Model artifact and any of its
1245+
supporting files. Not present for AutoML Models."""
1246+
return self._gca_resource.artifact_uri or None
12171247

12181248
@property
1219-
def description(self):
1249+
def description(self) -> str:
12201250
"""Description of the model."""
12211251
return self._gca_resource.description
12221252

@@ -1240,6 +1270,98 @@ def supported_export_formats(
12401270
for export_format in self._gca_resource.supported_export_formats
12411271
}
12421272

1273+
@property
1274+
def supported_deployment_resources_types(
1275+
self,
1276+
) -> List[aiplatform.gapic.Model.DeploymentResourcesType]:
1277+
"""List of deployment resource types accepted for this Model.
1278+
1279+
When this Model is deployed, its prediction resources are described by
1280+
the `prediction_resources` field of the objects returned by
1281+
`Endpoint.list_models()`. Because not all Models support all resource
1282+
configuration types, the configuration types this Model supports are
1283+
listed here.
1284+
1285+
If no configuration types are listed, the Model cannot be
1286+
deployed to an `Endpoint` and does not support online predictions
1287+
(`Endpoint.predict()` or `Endpoint.explain()`). Such a Model can serve
1288+
predictions by using a `BatchPredictionJob`, if it has at least one entry
1289+
each in `Model.supported_input_storage_formats` and
1290+
`Model.supported_output_storage_formats`."""
1291+
return list(self._gca_resource.supported_deployment_resources_types)
1292+
1293+
@property
1294+
def supported_input_storage_formats(self) -> List[str]:
1295+
"""The formats this Model supports in the `input_config` field of a
1296+
`BatchPredictionJob`. If `Model.predict_schemata.instance_schema_uri`
1297+
exists, the instances should be given as per that schema.
1298+
1299+
[Read the docs for more on batch prediction formats](https://cloud.google.com/vertex-ai/docs/predictions/batch-predictions#batch_request_input)
1300+
1301+
If this Model doesn't support any of these formats it means it cannot be
1302+
used with a `BatchPredictionJob`. However, if it has
1303+
`supported_deployment_resources_types`, it could serve online predictions
1304+
by using `Endpoint.predict()` or `Endpoint.explain()`.
1305+
"""
1306+
return list(self._gca_resource.supported_input_storage_formats)
1307+
1308+
@property
1309+
def supported_output_storage_formats(self) -> List[str]:
1310+
"""The formats this Model supports in the `output_config` field of a
1311+
`BatchPredictionJob`.
1312+
1313+
If both `Model.predict_schemata.instance_schema_uri` and
1314+
`Model.predict_schemata.prediction_schema_uri` exist, the predictions
1315+
are returned together with their instances. In other words, the
1316+
prediction has the original instance data first, followed by the actual
1317+
prediction content (as per the schema).
1318+
1319+
[Read the docs for more on batch prediction formats](https://cloud.google.com/vertex-ai/docs/predictions/batch-predictions)
1320+
1321+
If this Model doesn't support any of these formats it means it cannot be
1322+
used with a `BatchPredictionJob`. However, if it has
1323+
`supported_deployment_resources_types`, it could serve online predictions
1324+
by using `Endpoint.predict()` or `Endpoint.explain()`.
1325+
"""
1326+
return list(self._gca_resource.supported_output_storage_formats)
1327+
1328+
@property
1329+
def predict_schemata(self) -> Optional[aiplatform.gapic.PredictSchemata]:
1330+
"""The schemata that describe formats of the Model's predictions and
1331+
explanations, if available."""
1332+
return getattr(self._gca_resource, "predict_schemata")
1333+
1334+
@property
1335+
def training_job(self) -> Optional["aiplatform.training_jobs._TrainingJob"]:
1336+
"""The TrainingJob that uploaded this Model, if any.
1337+
1338+
Raises:
1339+
api_core.exceptions.NotFound: If the Model's training job resource
1340+
cannot be found on the Vertex service.
1341+
"""
1342+
job_name = getattr(self._gca_resource, "training_pipeline")
1343+
1344+
if not job_name:
1345+
return None
1346+
1347+
try:
1348+
return aiplatform.training_jobs._TrainingJob._get_and_return_subclass(
1349+
resource_name=job_name,
1350+
project=self.project,
1351+
location=self.location,
1352+
credentials=self.credentials,
1353+
)
1354+
except api_exceptions.NotFound:
1355+
raise api_exceptions.NotFound(
1356+
f"The training job used to create this model could not be found: {job_name}"
1357+
)
1358+
1359+
@property
1360+
def container_spec(self) -> Optional[aiplatform.gapic.ModelContainerSpec]:
1361+
"""The specification of the container that is to be used when deploying
1362+
this Model. Not present for AutoML Models."""
1363+
return getattr(self._gca_resource, "container_spec")
1364+
12431365
def __init__(
12441366
self,
12451367
model_name: str,

0 commit comments

Comments
 (0)