Skip to content

Commit 627fdf9

Browse files
sararob and sasha-gitg authored
feat: add enable_simple_view to PipelineJob.list() (#1614)
* feat: add enable_simple_view to PipelineJob.list()
* updates to pipelinejob.list read_mask
* run linter
* update to read_mask
* add placeholder for read_mask to system tests
* unit test fix
* add system test for read_mask filter
* move read mask fields to constants file
* add read_mask docstrings
* remove class name check

Co-authored-by: sasha-gitg <[email protected]>
1 parent a3cc5a3 commit 627fdf9

File tree

5 files changed: 136 additions, 0 deletions

5 files changed: 136 additions, 0 deletions

google/cloud/aiplatform/base.py

Lines changed: 25 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,7 @@
4848
from google.cloud.aiplatform.compat.types import encryption_spec as gca_encryption_spec
4949
from google.cloud.aiplatform.constants import base as base_constants
5050
from google.protobuf import json_format
51+
from google.protobuf import field_mask_pb2 as field_mask
5152

5253
# This is the default retry callback to be used with get methods.
5354
_DEFAULT_RETRY = retry.Retry()
@@ -1030,6 +1031,7 @@ def _list(
10301031
cls_filter: Callable[[proto.Message], bool] = lambda _: True,
10311032
filter: Optional[str] = None,
10321033
order_by: Optional[str] = None,
1034+
read_mask: Optional[field_mask.FieldMask] = None,
10331035
project: Optional[str] = None,
10341036
location: Optional[str] = None,
10351037
credentials: Optional[auth_credentials.Credentials] = None,
@@ -1052,6 +1054,14 @@ def _list(
10521054
Optional. A comma-separated list of fields to order by, sorted in
10531055
ascending order. Use "desc" after a field name for descending.
10541056
Supported fields: `display_name`, `create_time`, `update_time`
1057+
read_mask (field_mask.FieldMask):
1058+
Optional. A FieldMask with a list of strings passed via `paths`
1059+
indicating which fields to return for each resource in the response.
1060+
For example, passing
1061+
field_mask.FieldMask(paths=["create_time", "update_time"])
1062+
as `read_mask` would result in each returned VertexAiResourceNoun
1063+
in the result list only having the "create_time" and
1064+
"update_time" attributes.
10551065
project (str):
10561066
Optional. Project to retrieve list from. If not set, project
10571067
set in aiplatform.init will be used.
@@ -1067,6 +1077,7 @@ def _list(
10671077
Returns:
10681078
List[VertexAiResourceNoun] - A list of SDK resource objects
10691079
"""
1080+
10701081
resource = cls._empty_constructor(
10711082
project=project, location=location, credentials=credentials
10721083
)
@@ -1083,6 +1094,10 @@ def _list(
10831094
),
10841095
}
10851096

1097+
# `read_mask` is only passed from PipelineJob.list() for now
1098+
if read_mask is not None:
1099+
list_request["read_mask"] = read_mask
1100+
10861101
if filter:
10871102
list_request["filter"] = filter
10881103

@@ -1105,6 +1120,7 @@ def _list_with_local_order(
11051120
cls_filter: Callable[[proto.Message], bool] = lambda _: True,
11061121
filter: Optional[str] = None,
11071122
order_by: Optional[str] = None,
1123+
read_mask: Optional[field_mask.FieldMask] = None,
11081124
project: Optional[str] = None,
11091125
location: Optional[str] = None,
11101126
credentials: Optional[auth_credentials.Credentials] = None,
@@ -1127,6 +1143,14 @@ def _list_with_local_order(
11271143
Optional. A comma-separated list of fields to order by, sorted in
11281144
ascending order. Use "desc" after a field name for descending.
11291145
Supported fields: `display_name`, `create_time`, `update_time`
1146+
read_mask (field_mask.FieldMask):
1147+
Optional. A FieldMask with a list of strings passed via `paths`
1148+
indicating which fields to return for each resource in the response.
1149+
For example, passing
1150+
field_mask.FieldMask(paths=["create_time", "update_time"])
1151+
as `read_mask` would result in each returned VertexAiResourceNoun
1152+
in the result list only having the "create_time" and
1153+
"update_time" attributes.
11301154
project (str):
11311155
Optional. Project to retrieve list from. If not set, project
11321156
set in aiplatform.init will be used.
@@ -1145,6 +1169,7 @@ def _list_with_local_order(
11451169
cls_filter=cls_filter,
11461170
filter=filter,
11471171
order_by=None, # This method will handle the ordering locally
1172+
read_mask=read_mask,
11481173
project=project,
11491174
location=location,
11501175
credentials=credentials,

google/cloud/aiplatform/constants/pipeline.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,20 @@
3737

3838
# Pattern for an Artifact Registry URL.
3939
_VALID_AR_URL = re.compile(r"^https:\/\/([\w-]+)-kfp\.pkg\.dev\/.*")
40+
41+
# Fields to include in returned PipelineJob when enable_simple_view=True in PipelineJob.list()
42+
_READ_MASK_FIELDS = [
43+
"name",
44+
"state",
45+
"display_name",
46+
"pipeline_spec.pipeline_info",
47+
"create_time",
48+
"start_time",
49+
"end_time",
50+
"update_time",
51+
"labels",
52+
"template_uri",
53+
"template_metadata.version",
54+
"job_detail.pipeline_run_context",
55+
"job_detail.pipeline_context",
56+
]

google/cloud/aiplatform/pipeline_jobs.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from google.cloud.aiplatform.utils import yaml_utils
3939
from google.cloud.aiplatform.utils import pipeline_utils
4040
from google.protobuf import json_format
41+
from google.protobuf import field_mask_pb2 as field_mask
4142

4243
from google.cloud.aiplatform.compat.types import (
4344
pipeline_job as gca_pipeline_job,
@@ -56,6 +57,8 @@
5657
# Pattern for an Artifact Registry URL.
5758
_VALID_AR_URL = pipeline_constants._VALID_AR_URL
5859

60+
_READ_MASK_FIELDS = pipeline_constants._READ_MASK_FIELDS
61+
5962

6063
def _get_current_time() -> datetime.datetime:
6164
"""Gets the current timestamp."""
@@ -509,6 +512,7 @@ def list(
509512
cls,
510513
filter: Optional[str] = None,
511514
order_by: Optional[str] = None,
515+
enable_simple_view: Optional[bool] = False,
512516
project: Optional[str] = None,
513517
location: Optional[str] = None,
514518
credentials: Optional[auth_credentials.Credentials] = None,
@@ -530,6 +534,17 @@ def list(
530534
Optional. A comma-separated list of fields to order by, sorted in
531535
ascending order. Use "desc" after a field name for descending.
532536
Supported fields: `display_name`, `create_time`, `update_time`
537+
enable_simple_view (bool):
538+
Optional. Whether to pass the `read_mask` parameter to the list call.
539+
This will improve the performance of calling list(). However, the
540+
returned PipelineJob list will not include all fields for each PipelineJob.
541+
Setting this to True will exclude the following fields in your response:
542+
`runtime_config`, `service_account`, `network`, and some subfields of
543+
`pipeline_spec` and `job_detail`. The following fields will be included in
544+
each PipelineJob resource in your response: `state`, `display_name`,
545+
`pipeline_spec.pipeline_info`, `create_time`, `start_time`, `end_time`,
546+
`update_time`, `labels`, `template_uri`, `template_metadata.version`,
547+
`job_detail.pipeline_run_context`, `job_detail.pipeline_context`.
533548
project (str):
534549
Optional. Project to retrieve list from. If not set, project
535550
set in aiplatform.init will be used.
@@ -544,9 +559,18 @@ def list(
544559
List[PipelineJob] - A list of PipelineJob resource objects
545560
"""
546561

562+
read_mask_fields = None
563+
564+
if enable_simple_view:
565+
read_mask_fields = field_mask.FieldMask(paths=_READ_MASK_FIELDS)
566+
_LOGGER.warn(
567+
"By enabling simple view, the PipelineJob resources returned from this method will not contain all fields."
568+
)
569+
547570
return cls._list_with_local_order(
548571
filter=filter,
549572
order_by=order_by,
573+
read_mask=read_mask_fields,
550574
project=project,
551575
location=location,
552576
credentials=credentials,

tests/system/aiplatform/test_pipeline_job.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from google.cloud import aiplatform
2121
from tests.system.aiplatform import e2e_base
2222

23+
from google.protobuf.json_format import MessageToDict
24+
2325

2426
@pytest.mark.usefixtures("tear_down_resources")
2527
class TestPipelineJob(e2e_base.TestEndToEnd):
@@ -59,3 +61,14 @@ def training_pipeline(number_of_epochs: int = 10):
5961
shared_state.setdefault("resources", []).append(job)
6062

6163
job.wait()
64+
65+
list_with_read_mask = aiplatform.PipelineJob.list(enable_simple_view=True)
66+
list_without_read_mask = aiplatform.PipelineJob.list()
67+
68+
# enable_simple_view=True should apply the `read_mask` filter to limit PipelineJob fields returned
69+
assert "serviceAccount" in MessageToDict(
70+
list_without_read_mask[0].gca_resource._pb
71+
)
72+
assert "serviceAccount" not in MessageToDict(
73+
list_with_read_mask[0].gca_resource._pb
74+
)

tests/unit/aiplatform/test_pipeline_jobs.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from google.cloud import aiplatform
2929
from google.cloud.aiplatform import base
3030
from google.cloud.aiplatform import initializer
31+
from google.cloud.aiplatform.constants import pipeline as pipeline_constants
3132
from google.cloud.aiplatform_v1 import Context as GapicContext
3233
from google.cloud.aiplatform_v1 import MetadataStore as GapicMetadataStore
3334
from google.cloud.aiplatform.metadata import constants
@@ -37,6 +38,7 @@
3738
from google.cloud.aiplatform.utils import gcs_utils
3839
from google.cloud import storage
3940
from google.protobuf import json_format
41+
from google.protobuf import field_mask_pb2 as field_mask
4042

4143
from google.cloud.aiplatform.compat.services import (
4244
pipeline_service_client,
@@ -62,6 +64,9 @@
6264
_TEST_NETWORK = f"projects/{_TEST_PROJECT}/global/networks/{_TEST_PIPELINE_JOB_ID}"
6365

6466
_TEST_PIPELINE_JOB_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/pipelineJobs/{_TEST_PIPELINE_JOB_ID}"
67+
_TEST_PIPELINE_JOB_LIST_READ_MASK = field_mask.FieldMask(
68+
paths=pipeline_constants._READ_MASK_FIELDS
69+
)
6570

6671
_TEST_PIPELINE_PARAMETER_VALUES_LEGACY = {"string_param": "hello"}
6772
_TEST_PIPELINE_PARAMETER_VALUES = {
@@ -332,6 +337,17 @@ def mock_pipeline_service_list():
332337
with mock.patch.object(
333338
pipeline_service_client.PipelineServiceClient, "list_pipeline_jobs"
334339
) as mock_list_pipeline_jobs:
340+
mock_list_pipeline_jobs.return_value = [
341+
make_pipeline_job(
342+
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
343+
),
344+
make_pipeline_job(
345+
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
346+
),
347+
make_pipeline_job(
348+
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
349+
),
350+
]
335351
yield mock_list_pipeline_jobs
336352

337353

@@ -1354,6 +1370,47 @@ def test_list_pipeline_job(
13541370
request={"parent": _TEST_PARENT}
13551371
)
13561372

1373+
@pytest.mark.usefixtures(
1374+
"mock_pipeline_service_create",
1375+
"mock_pipeline_service_get",
1376+
"mock_pipeline_bucket_exists",
1377+
)
1378+
@pytest.mark.parametrize(
1379+
"job_spec",
1380+
[
1381+
_TEST_PIPELINE_SPEC_JSON,
1382+
_TEST_PIPELINE_SPEC_YAML,
1383+
_TEST_PIPELINE_JOB,
1384+
_TEST_PIPELINE_SPEC_LEGACY_JSON,
1385+
_TEST_PIPELINE_SPEC_LEGACY_YAML,
1386+
_TEST_PIPELINE_JOB_LEGACY,
1387+
],
1388+
)
1389+
def test_list_pipeline_job_with_read_mask(
1390+
self, mock_pipeline_service_list, mock_load_yaml_and_json
1391+
):
1392+
aiplatform.init(
1393+
project=_TEST_PROJECT,
1394+
staging_bucket=_TEST_GCS_BUCKET_NAME,
1395+
credentials=_TEST_CREDENTIALS,
1396+
)
1397+
1398+
job = pipeline_jobs.PipelineJob(
1399+
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
1400+
template_path=_TEST_TEMPLATE_PATH,
1401+
job_id=_TEST_PIPELINE_JOB_ID,
1402+
)
1403+
1404+
job.run()
1405+
job.list(enable_simple_view=True)
1406+
1407+
mock_pipeline_service_list.assert_called_once_with(
1408+
request={
1409+
"parent": _TEST_PARENT,
1410+
"read_mask": _TEST_PIPELINE_JOB_LIST_READ_MASK,
1411+
},
1412+
)
1413+
13571414
@pytest.mark.usefixtures(
13581415
"mock_pipeline_service_create",
13591416
"mock_pipeline_service_get",

0 commit comments

Comments
 (0)