
Commit 1f1226f

feat: add pipeline client init and run to vertex AI

1 parent f40f322 commit 1f1226f

File tree

5 files changed: +784 -1 lines changed

google/cloud/aiplatform/compat/types/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -44,6 +44,7 @@
     model_evaluation_slice as model_evaluation_slice_v1beta1,
     model_service as model_service_v1beta1,
     operation as operation_v1beta1,
+    pipeline_job as pipeline_job_v1beta1,
     pipeline_service as pipeline_service_v1beta1,
     pipeline_state as pipeline_state_v1beta1,
     prediction_service as prediction_service_v1beta1,
@@ -158,6 +159,7 @@
     model_evaluation_slice_v1beta1,
     model_service_v1beta1,
     operation_v1beta1,
+    pipeline_job_v1beta1,
     pipeline_service_v1beta1,
     pipeline_state_v1beta1,
     prediction_service_v1beta1,
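
For context, a minimal sketch of how a compat alias like the one added above is consumed elsewhere in the SDK; the import mirrors the one in the new pipeline module later in this commit, and the snippet is illustrative rather than part of this file's diff:

# Consuming the version-pinned alias registered above.
from google.cloud.aiplatform.compat.types import (
    pipeline_job_v1beta1 as gca_pipeline_job_v1beta1,
)

# The alias resolves to the v1beta1 pipeline_job types module, so callers
# can build requests without importing the GAPIC package path directly.
job = gca_pipeline_job_v1beta1.PipelineJob(display_name="example")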
Lines changed: 351 additions & 0 deletions
@@ -0,0 +1,351 @@
# -*- coding: utf-8 -*-

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import datetime
import logging
import re
import sys
import time
from typing import Dict, Optional

from google.auth import credentials as auth_credentials

from google.cloud.aiplatform import base
from google.cloud.aiplatform import compat
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.utils import pipeline_runtime_config_builder

from google.cloud.aiplatform.compat.services import pipeline_service_client
from google.cloud.aiplatform.compat.types import (
    pipeline_job_v1beta1 as gca_pipeline_job_v1beta1,
    pipeline_state_v1beta1 as gca_pipeline_state_v1beta1,
)

from google.rpc import code_pb2

logging.basicConfig(level=logging.INFO, stream=sys.stdout)
_LOGGER = base.Logger(__name__)

_PIPELINE_COMPLETE_STATES = set(
    [
        gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_SUCCEEDED,
        gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_FAILED,
        gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_CANCELLED,
        gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_PAUSED,
    ]
)

_PIPELINE_CLIENT_VERSION = "v1beta1"

# AIPlatformPipelines service API job name relative name prefix pattern.
_JOB_NAME_PATTERN = "{parent}/pipelineJobs/{job_id}"

# Pattern for valid names used as a Vertex resource name.
_VALID_NAME_PATTERN = re.compile("^[a-z][-a-z0-9]{0,127}$")


def _get_current_time() -> datetime.datetime:
    """Gets the current timestamp.

    Assumed minimal implementation: PipelineJob.__init__ below references
    this helper to build default job IDs.
    """
    return datetime.datetime.now()


def _set_enable_caching_value(
    pipeline_spec: Dict, enable_caching: bool
) -> None:
    """Sets pipeline tasks caching options.

    Args:
        pipeline_spec: The dictionary of pipeline spec.
        enable_caching: Whether to enable caching.
    """
    for component in [pipeline_spec["root"]] + list(
        pipeline_spec["components"].values()
    ):
        if "dag" in component:
            for task in component["dag"]["tasks"].values():
                task["cachingOptions"] = {"enableCache": enable_caching}
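# Illustration (not from the original commit): given a loaded spec such as
#
#     spec = {
#         "root": {"dag": {"tasks": {"train": {}, "evaluate": {}}}},
#         "components": {},
#     }
#
# _set_enable_caching_value(spec, False) leaves every task carrying
# task["cachingOptions"] == {"enableCache": False}, which is how the
# job-level enable_caching flag fans out to each task in the DAG.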
class PipelineJob(base.VertexAiResourceNounWithFutureManager):

    client_class = utils.PipelineClientWithOverride
    _is_client_prediction_client = False

    _resource_noun = "pipelineJobs"
    _getter_method = "get_pipeline_job"
    _list_method = "list_pipeline_jobs"
    _cancel_method = "cancel_pipeline_job"
    _delete_method = "delete_pipeline_job"

    def __init__(
        self,
        display_name: str,
        job_spec_path: str,
        job_id: Optional[str] = None,
        pipeline_root: Optional[str] = None,
        parameter_values: Optional[Dict] = None,
        enable_caching: bool = True,
        encryption_spec_key_name: Optional[str] = None,
        network: Optional[str] = None,
        labels: Optional[Dict] = None,
        service_account: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
        project: Optional[str] = None,
        location: Optional[str] = None,
    ):
        """Retrieves a PipelineJob resource and instantiates its
        representation.

        Args:
            display_name (str):
                Required. The user-defined name of this Pipeline.
            job_spec_path (str):
                Required. The path of the PipelineJob JSON file. It can be
                a local path or a GCS URI. Example: "gs://my-bucket/pipeline_job.json".
            job_id (Optional[str]):
                Optionally, the user can provide the unique ID of the job run.
                If not specified, the pipeline name plus a timestamp will be used.
            pipeline_root (Optional[str]):
                Optionally the user can override the pipeline root specified
                at compile time. Defaults to the staging bucket.
            parameter_values (Optional[Dict]):
                The mapping from runtime parameter names to the values that
                control the pipeline run.
            enable_caching (bool):
                Optional. Whether to turn on caching for the run. Defaults to True.
            encryption_spec_key_name (Optional[str]):
                Optional. The Cloud KMS resource identifier of the customer
                managed encryption key used to protect the job. Has the form:
                ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
                The key needs to be in the same region as where the compute
                resource is created.

                If this is set, then all resources created by the PipelineJob
                will be encrypted with the provided encryption key.

                Overrides encryption_spec_key_name set in aiplatform.init.
            network (Optional[str]):
                Optional. The full name of the Compute Engine network to which
                the job should be peered. For example,
                projects/12345/global/networks/myVPC. Private services access
                must already be configured for the network. If left
                unspecified, the job is not peered with any network.
            labels (Optional[Dict]):
                The user defined metadata to organize PipelineJob.
            service_account (Optional[str]):
                Optional. Specifies the service account for workload run-as
                account. Users submitting jobs must have act-as permission on
                this run-as account.
            credentials (Optional[auth_credentials.Credentials]):
                Custom credentials to use to create this pipeline job.
                Overrides credentials set in aiplatform.init.
            project (Optional[str]):
                Optional project to retrieve PipelineJob from. If not set,
                the project set in aiplatform.init will be used.
            location (Optional[str]):
                Optional location to retrieve PipelineJob from. If not set,
                the location set in aiplatform.init will be used.
        """
        utils.validate_display_name(display_name)

        super().__init__(project=project, location=location, credentials=credentials)

        self._parent = initializer.global_config.common_location_path(
            project=project, location=location
        )
        pipeline_root = pipeline_root or initializer.global_config.staging_bucket
        pipeline_spec = utils.load_json(job_spec_path)
        pipeline_name = pipeline_spec["pipelineSpec"]["pipelineInfo"]["name"]
        # Default job ID: sanitized pipeline name plus a timestamp.
        job_id = job_id or "{pipeline_name}-{timestamp}".format(
            pipeline_name=re.sub("[^-0-9a-z]+", "-", pipeline_name.lower())
            .lstrip("-")
            .rstrip("-"),
            timestamp=_get_current_time().strftime("%Y%m%d%H%M%S"),
        )
        if not _VALID_NAME_PATTERN.match(job_id):
            raise ValueError(
                "Generated job ID: {} is illegal as a Vertex pipelines job ID. "
                "Expecting an ID following the regex pattern "
                '"[a-z][-a-z0-9]{{0,127}}"'.format(job_id)
            )

        job_name = _JOB_NAME_PATTERN.format(parent=self._parent, job_id=job_id)

        pipeline_spec["name"] = job_name
        pipeline_spec["displayName"] = job_id

        builder = pipeline_runtime_config_builder.PipelineRuntimeConfigBuilder.from_job_spec_json(
            pipeline_spec
        )
        builder.update_pipeline_root(pipeline_root)
        builder.update_runtime_parameters(parameter_values)

        runtime_config = builder.build()
        pipeline_spec["runtimeConfig"] = runtime_config

        _set_enable_caching_value(pipeline_spec["pipelineSpec"], enable_caching)

        if encryption_spec_key_name is not None:
            pipeline_spec["encryptionSpec"] = {"kmsKeyName": encryption_spec_key_name}
        if service_account is not None:
            pipeline_spec["serviceAccount"] = service_account
        if network is not None:
            pipeline_spec["network"] = network

        if labels:
            # isinstance() needs the runtime type `dict`, not `typing.Dict`.
            if not isinstance(labels, dict):
                raise ValueError(
                    "Expect labels to be a mapping of string key value pairs. "
                    'Got "{}" of type "{}"'.format(labels, type(labels))
                )
            for k, v in labels.items():
                if not isinstance(k, str) or not isinstance(v, str):
                    raise ValueError(
                        "Expect labels to be a mapping of string key value pairs. "
                        'Got "{}".'.format(labels)
                    )

            pipeline_spec["labels"] = labels

        self._gca_resource = gca_pipeline_job_v1beta1.PipelineJob(
            display_name=display_name,
            pipeline_spec=pipeline_spec,
            labels=labels,
            runtime_config=None,
            encryption_spec=initializer.global_config.get_encryption_spec(
                encryption_spec_key_name=encryption_spec_key_name
            ),
            service_account=service_account,
            network=network,
        )
    @base.optional_sync()
    def run(
        self,
        service_account: Optional[str] = None,
        network: Optional[str] = None,
        sync: bool = True,
    ) -> None:
        """Run this configured PipelineJob.

        Args:
            service_account (str):
                Optional. Specifies the service account for workload run-as account.
                Users submitting jobs must have act-as permission on this run-as account.
            network (str):
                Optional. The full name of the Compute Engine network to which the job
                should be peered. For example, projects/12345/global/networks/myVPC.
                Private services access must already be configured for the network.
                If left unspecified, the job is not peered with any network.
            sync (bool):
                Whether to execute this method synchronously. If False, this method
                will unblock and it will be executed in a concurrent Future.
        """

        if service_account:
            self._gca_resource.pipeline_spec.service_account = service_account

        if network:
            self._gca_resource.pipeline_spec.network = network

        _LOGGER.log_create_with_lro(self.__class__)

        self._gca_resource = self.api_client.select_version(
            _PIPELINE_CLIENT_VERSION
        ).create_pipeline_job(
            parent=self._parent, pipeline_job=self._gca_resource
        )

        _LOGGER.log_create_complete_with_getter(
            self.__class__, self._gca_resource, "pipeline_job"
        )

        _LOGGER.info("View Pipeline Job:\n%s" % self._dashboard_uri())

        self._block_until_complete()
    @property
    def pipeline_spec(self):
        return self._gca_resource.pipeline_spec

    @property
    def state(self) -> Optional[gca_pipeline_state_v1beta1.PipelineState]:
        """Current pipeline state."""

        if self._assert_has_run():
            return

        self._sync_gca_resource()
        return self._gca_resource.state

    @property
    def _has_run(self) -> bool:
        """Helper property to check if this pipeline job has been run."""
        return self._gca_resource is not None

    def _assert_has_run(self) -> bool:
        """Helper method to assert that this pipeline has run."""
        if not self._has_run:
            if self._is_waiting_to_run():
                return True
            raise RuntimeError(
                "PipelineJob has not been launched. You must run this"
                " PipelineJob using PipelineJob.run. "
            )
        return False

    @property
    def has_failed(self) -> bool:
        """Returns True if the pipeline has failed, False otherwise."""
        self._assert_has_run()
        return (
            self.state
            == gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_FAILED
        )

    def _dashboard_uri(self) -> str:
        """Helper method to compose the dashboard uri where the pipeline can
        be viewed."""
        fields = utils.extract_fields_from_resource_name(self.resource_name)
        url = f"https://console.cloud.google.com/ai/platform/locations/{fields.location}/pipelines/runs/{fields.id}?project={fields.project}"
        return url

    def _sync_gca_resource(self):
        """Helper method to sync the local gca_resource against the service."""

        self._gca_resource = self.api_client.select_version(
            _PIPELINE_CLIENT_VERSION
        ).get_pipeline_job(name=self.resource_name)

    def _block_until_complete(self):
        """Helper method to block and check on job until complete."""

        # Used these numbers so failures surface fast
        wait = 5  # poll the service every five seconds
        log_wait = 5  # seconds between state log lines
        max_wait = 60 * 5  # cap the logging interval at 5 minutes
        multiplier = 2  # scale the logging interval by 2 every iteration

        previous_time = time.time()
        while self.state not in _PIPELINE_COMPLETE_STATES:
            current_time = time.time()
            if current_time - previous_time >= log_wait:
                _LOGGER.info(
                    "%s %s current state:\n%s"
                    % (
                        self.__class__.__name__,
                        self._gca_resource.name,
                        self._gca_resource.state,
                    )
                )
                log_wait = min(log_wait * multiplier, max_wait)
                previous_time = current_time
            time.sleep(wait)

        self._raise_failure()

        _LOGGER.log_action_completed_against_resource("run", "completed", self)

        if self._gca_resource.name and not self.has_failed:
            _LOGGER.info("Pipeline Job available at:\n%s" % self._dashboard_uri())

    def _raise_failure(self):
        """Helper method to raise failure if PipelineJob fails.

        Raises:
            RuntimeError: If pipeline failed.
        """

        if self._gca_resource.error.code != code_pb2.OK:
            raise RuntimeError("Pipeline failed with:\n%s" % self._gca_resource.error)
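
For readers landing on this commit, a short usage sketch of the new surface. Assumptions: aiplatform.init has been wired up as usual, a compiled pipeline job spec exists at the GCS path shown, and PipelineJob is imported from wherever the package exports the new module, which this diff does not show.

from google.cloud import aiplatform

aiplatform.init(
    project="my-project",  # placeholder project ID
    location="us-central1",
    staging_bucket="gs://my-bucket",  # becomes the default pipeline_root
)

job = PipelineJob(  # class defined in the new module above
    display_name="example-pipeline",
    job_spec_path="gs://my-bucket/pipeline_job.json",  # compiled job spec
    parameter_values={"learning_rate": "0.01"},
    enable_caching=True,
)

# Creates the job via the v1beta1 pipeline service and blocks until a
# terminal state; pass sync=False to run in a concurrent Future instead.
job.run(sync=True)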
