Skip to content

Commit db580ad

Browse files
authored
feat: Add wait_for_resource_creation to BatchPredictionJob and unblock async creation when model is pending creation. (#660)
1 parent 4ad67dc commit db580ad

File tree

8 files changed

+177
-86
lines changed

8 files changed

+177
-86
lines changed

README.rst

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,39 @@ Please visit `Importing models to Vertex AI`_ for a detailed overview:
274274
.. _Importing models to Vertex AI: https://cloud.google.com/vertex-ai/docs/general/import-model
275275

276276

277+
Batch Prediction
278+
----------------
279+
280+
To create a batch prediction job:
281+
282+
.. code-block:: Python
283+
284+
model = aiplatform.Model('/projects/my-project/locations/us-central1/models/{MODEL_ID}')
285+
286+
batch_prediction_job = model.batch_predict(
287+
job_display_name='my-batch-prediction-job',
288+
instances_format='csv',
289+
machine_type='n1-standard-4',
290+
gcs_source=['gs://path/to/my/file.csv'],
291+
gcs_destination_prefix='gs://path/to/my/batch_prediction/results/'
292+
)
293+
294+
You can also create a batch prediction job asynchronously by including the ``sync=False`` argument:
295+
296+
.. code-block:: Python
297+
298+
batch_prediction_job = model.batch_predict(..., sync=False)
299+
300+
# wait for resource to be created
301+
batch_prediction_job.wait_for_resource_creation()
302+
303+
# get the state
304+
batch_prediction_job.state
305+
306+
# block until job is complete
307+
batch_prediction_job.wait()
308+
309+
277310
Endpoints
278311
---------
279312

google/cloud/aiplatform/base.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -680,17 +680,21 @@ def wrapper(*args, **kwargs):
680680
inspect.getfullargspec(method).annotations["return"]
681681
)
682682

683+
# object produced by the method
684+
returned_object = bound_args.arguments.get(return_input_arg)
685+
683686
# is a classmethod that creates the object and returns it
684687
if args and inspect.isclass(args[0]):
685-
# assumes classmethod is our resource noun
686-
returned_object = args[0]._empty_constructor()
688+
689+
# assumes class in classmethod is the resource noun
690+
returned_object = (
691+
args[0]._empty_constructor()
692+
if not returned_object
693+
else returned_object
694+
)
687695
self = returned_object
688696

689697
else: # instance method
690-
691-
# object produced by the method
692-
returned_object = bound_args.arguments.get(return_input_arg)
693-
694698
# if we're returning an input object
695699
if returned_object and returned_object is not self:
696700

google/cloud/aiplatform/jobs.py

Lines changed: 60 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,6 @@
3232
from google.cloud import aiplatform
3333
from google.cloud.aiplatform import base
3434
from google.cloud.aiplatform import compat
35-
from google.cloud.aiplatform import constants
36-
from google.cloud.aiplatform import initializer
37-
from google.cloud.aiplatform import hyperparameter_tuning
38-
from google.cloud.aiplatform import utils
39-
from google.cloud.aiplatform.utils import console_utils
40-
from google.cloud.aiplatform.utils import source_utils
41-
from google.cloud.aiplatform.utils import worker_spec_utils
42-
43-
from google.cloud.aiplatform.compat.services import job_service_client
4435
from google.cloud.aiplatform.compat.types import (
4536
batch_prediction_job as gca_bp_job_compat,
4637
batch_prediction_job_v1 as gca_bp_job_v1,
@@ -58,6 +49,13 @@
5849
machine_resources_v1beta1 as gca_machine_resources_v1beta1,
5950
study as gca_study_compat,
6051
)
52+
from google.cloud.aiplatform import constants
53+
from google.cloud.aiplatform import initializer
54+
from google.cloud.aiplatform import hyperparameter_tuning
55+
from google.cloud.aiplatform import utils
56+
from google.cloud.aiplatform.utils import console_utils
57+
from google.cloud.aiplatform.utils import source_utils
58+
from google.cloud.aiplatform.utils import worker_spec_utils
6159

6260

6361
_LOGGER = base.Logger(__name__)
@@ -352,7 +350,7 @@ def completion_stats(self) -> Optional[gca_completion_stats.CompletionStats]:
352350
def create(
353351
cls,
354352
job_display_name: str,
355-
model_name: str,
353+
model_name: Union[str, "aiplatform.Model"],
356354
instances_format: str = "jsonl",
357355
predictions_format: str = "jsonl",
358356
gcs_source: Optional[Union[str, Sequence[str]]] = None,
@@ -384,10 +382,12 @@ def create(
384382
Required. The user-defined name of the BatchPredictionJob.
385383
The name can be up to 128 characters long and can consist
386384
of any UTF-8 characters.
387-
model_name (str):
385+
model_name (Union[str, aiplatform.Model]):
388386
Required. A fully-qualified model resource name or model ID.
389387
Example: "projects/123/locations/us-central1/models/456" or
390388
"456" when project and location are initialized or passed.
389+
390+
Or an instance of aiplatform.Model.
391391
instances_format (str):
392392
Required. The format in which instances are given, must be one
393393
of "jsonl", "csv", "bigquery", "tf-record", "tf-record-gzip",
@@ -533,15 +533,17 @@ def create(
533533
"""
534534

535535
utils.validate_display_name(job_display_name)
536+
536537
if labels:
537538
utils.validate_labels(labels)
538539

539-
model_name = utils.full_resource_name(
540-
resource_name=model_name,
541-
resource_noun="models",
542-
project=project,
543-
location=location,
544-
)
540+
if isinstance(model_name, str):
541+
model_name = utils.full_resource_name(
542+
resource_name=model_name,
543+
resource_noun="models",
544+
project=project,
545+
location=location,
546+
)
545547

546548
# Raise error if both or neither source URIs are provided
547549
if bool(gcs_source) == bool(bigquery_source):
@@ -570,6 +572,7 @@ def create(
570572
f"{predictions_format} is not an accepted prediction format "
571573
f"type. Please choose from: {constants.BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS}"
572574
)
575+
573576
gca_bp_job = gca_bp_job_compat
574577
gca_io = gca_io_compat
575578
gca_machine_resources = gca_machine_resources_compat
@@ -584,7 +587,6 @@ def create(
584587

585588
# Required Fields
586589
gapic_batch_prediction_job.display_name = job_display_name
587-
gapic_batch_prediction_job.model = model_name
588590

589591
input_config = gca_bp_job.BatchPredictionJob.InputConfig()
590592
output_config = gca_bp_job.BatchPredictionJob.OutputConfig()
@@ -657,63 +659,43 @@ def create(
657659
metadata=explanation_metadata, parameters=explanation_parameters
658660
)
659661

660-
# TODO (b/174502913): Support private feature once released
661-
662-
api_client = cls._instantiate_client(location=location, credentials=credentials)
662+
empty_batch_prediction_job = cls._empty_constructor(
663+
project=project, location=location, credentials=credentials,
664+
)
663665

664666
return cls._create(
665-
api_client=api_client,
666-
parent=initializer.global_config.common_location_path(
667-
project=project, location=location
668-
),
669-
batch_prediction_job=gapic_batch_prediction_job,
667+
empty_batch_prediction_job=empty_batch_prediction_job,
668+
model_or_model_name=model_name,
669+
gca_batch_prediction_job=gapic_batch_prediction_job,
670670
generate_explanation=generate_explanation,
671-
project=project or initializer.global_config.project,
672-
location=location or initializer.global_config.location,
673-
credentials=credentials or initializer.global_config.credentials,
674671
sync=sync,
675672
)
676673

677674
@classmethod
678-
@base.optional_sync()
675+
@base.optional_sync(return_input_arg="empty_batch_prediction_job")
679676
def _create(
680677
cls,
681-
api_client: job_service_client.JobServiceClient,
682-
parent: str,
683-
batch_prediction_job: Union[
678+
empty_batch_prediction_job: "BatchPredictionJob",
679+
model_or_model_name: Union[str, "aiplatform.Model"],
680+
gca_batch_prediction_job: Union[
684681
gca_bp_job_v1beta1.BatchPredictionJob, gca_bp_job_v1.BatchPredictionJob
685682
],
686683
generate_explanation: bool,
687-
project: str,
688-
location: str,
689-
credentials: Optional[auth_credentials.Credentials],
690684
sync: bool = True,
691685
) -> "BatchPredictionJob":
692686
"""Create a batch prediction job.
693687
694688
Args:
695-
api_client (dataset_service_client.DatasetServiceClient):
696-
Required. An instance of DatasetServiceClient with the correct api_endpoint
697-
already set based on user's preferences.
698-
batch_prediction_job (gca_bp_job.BatchPredictionJob):
689+
empty_batch_prediction_job (BatchPredictionJob):
690+
Required. BatchPredictionJob without _gca_resource populated.
691+
model_or_model_name (Union[str, aiplatform.Model]):
692+
Required. A fully-qualified model resource name or
693+
an instance of aiplatform.Model.
694+
gca_batch_prediction_job (gca_bp_job.BatchPredictionJob):
699695
Required. a batch prediction job proto for creating a batch prediction job on Vertex AI.
700696
generate_explanation (bool):
701697
Required. Generate explanation along with the batch prediction
702698
results.
703-
parent (str):
704-
Required. Also known as common location path, that usually contains the
705-
project and location that the user provided to the upstream method.
706-
Example: "projects/my-prj/locations/us-central1"
707-
project (str):
708-
Required. Project to upload this model to. Overrides project set in
709-
aiplatform.init.
710-
location (str):
711-
Required. Location to upload this model to. Overrides location set in
712-
aiplatform.init.
713-
credentials (Optional[auth_credentials.Credentials]):
714-
Custom credentials to use to upload this model. Overrides
715-
credentials set in aiplatform.init.
716-
717699
Returns:
718700
(jobs.BatchPredictionJob):
719701
Instantiated representation of the created batch prediction job.
@@ -725,21 +707,34 @@ def _create(
725707
by Vertex AI.
726708
"""
727709
# select v1beta1 if explain else use default v1
710+
711+
parent = initializer.global_config.common_location_path(
712+
project=empty_batch_prediction_job.project,
713+
location=empty_batch_prediction_job.location,
714+
)
715+
716+
model_resource_name = (
717+
model_or_model_name
718+
if isinstance(model_or_model_name, str)
719+
else model_or_model_name.resource_name
720+
)
721+
722+
gca_batch_prediction_job.model = model_resource_name
723+
724+
api_client = empty_batch_prediction_job.api_client
725+
728726
if generate_explanation:
729727
api_client = api_client.select_version(compat.V1BETA1)
730728

731729
_LOGGER.log_create_with_lro(cls)
732730

733731
gca_batch_prediction_job = api_client.create_batch_prediction_job(
734-
parent=parent, batch_prediction_job=batch_prediction_job
732+
parent=parent, batch_prediction_job=gca_batch_prediction_job
735733
)
736734

737-
batch_prediction_job = cls(
738-
batch_prediction_job_name=gca_batch_prediction_job.name,
739-
project=project,
740-
location=location,
741-
credentials=credentials,
742-
)
735+
empty_batch_prediction_job._gca_resource = gca_batch_prediction_job
736+
737+
batch_prediction_job = empty_batch_prediction_job
743738

744739
_LOGGER.log_create_complete(cls, batch_prediction_job._gca_resource, "bpj")
745740

@@ -843,6 +838,10 @@ def iter_outputs(
843838
f"on your prediction output:\n{output_info}"
844839
)
845840

841+
def wait_for_resource_creation(self) -> None:
842+
"""Waits until resource has been created."""
843+
self._wait_for_resource_creation()
844+
846845

847846
class _RunnableJob(_Job):
848847
"""ABC to interface job as a runnable training class."""

google/cloud/aiplatform/models.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -981,7 +981,6 @@ def undeploy(
981981
if deployed_model_id in traffic_split and traffic_split[deployed_model_id]:
982982
raise ValueError("Model being undeployed should have 0 traffic.")
983983
if sum(traffic_split.values()) != 100:
984-
# TODO(b/172678233) verify every referenced deployed model exists
985984
raise ValueError(
986985
"Sum of all traffic within traffic split needs to be 100."
987986
)
@@ -2167,11 +2166,10 @@ def batch_predict(
21672166
(jobs.BatchPredictionJob):
21682167
Instantiated representation of the created batch prediction job.
21692168
"""
2170-
self.wait()
21712169

21722170
return jobs.BatchPredictionJob.create(
21732171
job_display_name=job_display_name,
2174-
model_name=self.resource_name,
2172+
model_name=self,
21752173
instances_format=instances_format,
21762174
predictions_format=predictions_format,
21772175
gcs_source=gcs_source,

tests/system/aiplatform/e2e_base.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,17 @@ def _temp_prefix(cls) -> str:
4343
"""
4444
pass
4545

46+
@classmethod
47+
def _make_display_name(cls, key: str) -> str:
48+
"""Helper method to make unique display_names.
49+
50+
Args:
51+
key (str): Required. Identifier for the display name.
52+
Returns:
53+
Unique display name.
54+
"""
55+
return f"{cls._temp_prefix}-{key}-{uuid.uuid4()}"
56+
4657
def setup_method(self):
4758
importlib.reload(initializer)
4859
importlib.reload(aiplatform)

0 commit comments

Comments
 (0)