 from google.cloud import aiplatform
 from google.cloud.aiplatform import base
 from google.cloud.aiplatform import compat
-from google.cloud.aiplatform import constants
-from google.cloud.aiplatform import initializer
-from google.cloud.aiplatform import hyperparameter_tuning
-from google.cloud.aiplatform import utils
-from google.cloud.aiplatform.utils import console_utils
-from google.cloud.aiplatform.utils import source_utils
-from google.cloud.aiplatform.utils import worker_spec_utils
-
-from google.cloud.aiplatform.compat.services import job_service_client
 from google.cloud.aiplatform.compat.types import (
     batch_prediction_job as gca_bp_job_compat,
     batch_prediction_job_v1 as gca_bp_job_v1,
@@ -58,6 +49,13 @@
     machine_resources_v1beta1 as gca_machine_resources_v1beta1,
     study as gca_study_compat,
 )
+from google.cloud.aiplatform import constants
+from google.cloud.aiplatform import initializer
+from google.cloud.aiplatform import hyperparameter_tuning
+from google.cloud.aiplatform import utils
+from google.cloud.aiplatform.utils import console_utils
+from google.cloud.aiplatform.utils import source_utils
+from google.cloud.aiplatform.utils import worker_spec_utils


 _LOGGER = base.Logger(__name__)
@@ -352,7 +350,7 @@ def completion_stats(self) -> Optional[gca_completion_stats.CompletionStats]:
     def create(
         cls,
         job_display_name: str,
-        model_name: str,
+        model_name: Union[str, "aiplatform.Model"],
         instances_format: str = "jsonl",
         predictions_format: str = "jsonl",
         gcs_source: Optional[Union[str, Sequence[str]]] = None,
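With this change, model_name accepts either a model resource string or an aiplatform.Model object. A minimal usage sketch of both call styles (the project, bucket paths, and the gcs_destination_prefix argument are illustrative placeholders, not taken from this diff):

from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")  # placeholder project

# Passing a model ID or full resource name, as before.
job_from_id = aiplatform.BatchPredictionJob.create(
    job_display_name="bp-job-from-id",
    model_name="456",  # placeholder model ID
    gcs_source="gs://my-bucket/instances.jsonl",  # placeholder input
    gcs_destination_prefix="gs://my-bucket/output/",  # assumed output argument
)

# Passing an aiplatform.Model instance directly, enabled by this change.
model = aiplatform.Model("projects/123/locations/us-central1/models/456")
job_from_model = aiplatform.BatchPredictionJob.create(
    job_display_name="bp-job-from-model",
    model_name=model,
    gcs_source="gs://my-bucket/instances.jsonl",
    gcs_destination_prefix="gs://my-bucket/output/",
)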
@@ -384,10 +382,12 @@ def create(
                 Required. The user-defined name of the BatchPredictionJob.
                 The name can be up to 128 characters long and can be consist
                 of any UTF-8 characters.
-            model_name (str):
+            model_name (Union[str, aiplatform.Model]):
                 Required. A fully-qualified model resource name or model ID.
                 Example: "projects/123/locations/us-central1/models/456" or
                 "456" when project and location are initialized or passed.
+
+                Or an instance of aiplatform.Model.
             instances_format (str):
                 Required. The format in which instances are given, must be one
                 of "jsonl", "csv", "bigquery", "tf-record", "tf-record-gzip",
@@ -533,15 +533,17 @@ def create(
         """
 
         utils.validate_display_name(job_display_name)
+
         if labels:
             utils.validate_labels(labels)
 
-        model_name = utils.full_resource_name(
-            resource_name=model_name,
-            resource_noun="models",
-            project=project,
-            location=location,
-        )
+        if isinstance(model_name, str):
+            model_name = utils.full_resource_name(
+                resource_name=model_name,
+                resource_noun="models",
+                project=project,
+                location=location,
+            )
 
         # Raise error if both or neither source URIs are provided
         if bool(gcs_source) == bool(bigquery_source):
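The new isinstance guard keeps the existing behavior for strings, where a bare model ID is expanded into a full resource name, while aiplatform.Model instances skip this step and are resolved later in _create(). A simplified stand-in for what this branching does (illustrative only, not the library's actual utils.full_resource_name implementation), using the example values from the docstring above:

def resolve_model_reference(model, project: str, location: str) -> str:
    """Illustrative sketch of the branching added in this hunk."""
    if isinstance(model, str):
        # A full resource name passes through; a bare ID is expanded using
        # the configured project and location.
        if model.startswith("projects/"):
            return model
        return f"projects/{project}/locations/{location}/models/{model}"
    # An aiplatform.Model already carries its own resource name.
    return model.resource_name

assert (
    resolve_model_reference("456", "123", "us-central1")
    == "projects/123/locations/us-central1/models/456"
)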
@@ -570,6 +572,7 @@ def create(
                 f"{predictions_format} is not an accepted prediction format "
                 f"type. Please choose from: {constants.BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS} "
             )
+
         gca_bp_job = gca_bp_job_compat
         gca_io = gca_io_compat
         gca_machine_resources = gca_machine_resources_compat
@@ -584,7 +587,6 @@ def create(
 
         # Required Fields
         gapic_batch_prediction_job.display_name = job_display_name
-        gapic_batch_prediction_job.model = model_name
 
         input_config = gca_bp_job.BatchPredictionJob.InputConfig()
         output_config = gca_bp_job.BatchPredictionJob.OutputConfig()
@@ -657,63 +659,43 @@ def create(
                 metadata=explanation_metadata, parameters=explanation_parameters
             )
 
-        # TODO (b/174502913): Support private feature once released
-
-        api_client = cls._instantiate_client(location=location, credentials=credentials)
+        empty_batch_prediction_job = cls._empty_constructor(
+            project=project, location=location, credentials=credentials,
+        )
 
         return cls._create(
-            api_client=api_client,
-            parent=initializer.global_config.common_location_path(
-                project=project, location=location
-            ),
-            batch_prediction_job=gapic_batch_prediction_job,
+            empty_batch_prediction_job=empty_batch_prediction_job,
+            model_or_model_name=model_name,
+            gca_batch_prediction_job=gapic_batch_prediction_job,
             generate_explanation=generate_explanation,
-            project=project or initializer.global_config.project,
-            location=location or initializer.global_config.location,
-            credentials=credentials or initializer.global_config.credentials,
             sync=sync,
         )
 
     @classmethod
-    @base.optional_sync()
+    @base.optional_sync(return_input_arg="empty_batch_prediction_job")
     def _create(
         cls,
-        api_client: job_service_client.JobServiceClient,
-        parent: str,
-        batch_prediction_job: Union[
+        empty_batch_prediction_job: "BatchPredictionJob",
+        model_or_model_name: Union[str, "aiplatform.Model"],
+        gca_batch_prediction_job: Union[
             gca_bp_job_v1beta1.BatchPredictionJob, gca_bp_job_v1.BatchPredictionJob
         ],
         generate_explanation: bool,
-        project: str,
-        location: str,
-        credentials: Optional[auth_credentials.Credentials],
         sync: bool = True,
     ) -> "BatchPredictionJob":
         """Create a batch prediction job.
 
         Args:
-            api_client (dataset_service_client.DatasetServiceClient):
-                Required. An instance of DatasetServiceClient with the correct api_endpoint
-                already set based on user's preferences.
-            batch_prediction_job (gca_bp_job.BatchPredictionJob):
+            empty_batch_prediction_job (BatchPredictionJob):
+                Required. BatchPredictionJob without _gca_resource populated.
+            model_or_model_name (Union[str, aiplatform.Model]):
+                Required. Required. A fully-qualified model resource name or
+                an instance of aiplatform.Model.
+            gca_batch_prediction_job (gca_bp_job.BatchPredictionJob):
                 Required. a batch prediction job proto for creating a batch prediction job on Vertex AI.
             generate_explanation (bool):
                 Required. Generate explanation along with the batch prediction
                 results.
-            parent (str):
-                Required. Also known as common location path, that usually contains the
-                project and location that the user provided to the upstream method.
-                Example: "projects/my-prj/locations/us-central1"
-            project (str):
-                Required. Project to upload this model to. Overrides project set in
-                aiplatform.init.
-            location (str):
-                Required. Location to upload this model to. Overrides location set in
-                aiplatform.init.
-            credentials (Optional[auth_credentials.Credentials]):
-                Custom credentials to use to upload this model. Overrides
-                credentials set in aiplatform.init.
-
         Returns:
             (jobs.BatchPredictionJob):
                 Instantiated representation of the created batch prediction job.
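The create()/_create() split follows the SDK's empty-constructor pattern: create() builds a shell BatchPredictionJob via _empty_constructor() and passes it down, and return_input_arg tells the optional_sync decorator which argument to hand back immediately when sync=False while the real creation runs on a background thread. A simplified sketch of that idea (an illustration of the pattern only, not the actual base.optional_sync implementation):

import functools
import threading


def optional_sync(return_input_arg=None):
    """Illustrative decorator: run the wrapped creation call in a thread when sync=False."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(cls, *args, sync=True, **kwargs):
            if sync:
                # Synchronous path: block until the resource exists.
                return func(cls, *args, sync=sync, **kwargs)
            # Asynchronous path: start the creation in the background and
            # immediately return the not-yet-populated object named by
            # return_input_arg (here, the empty BatchPredictionJob).
            placeholder = kwargs[return_input_arg]
            worker = threading.Thread(
                target=func, args=(cls, *args), kwargs={**kwargs, "sync": sync}
            )
            worker.start()
            return placeholder

        return wrapper

    return decorator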
@@ -725,21 +707,34 @@ def _create(
             by Vertex AI.
         """
         # select v1beta1 if explain else use default v1
+
+        parent = initializer.global_config.common_location_path(
+            project=empty_batch_prediction_job.project,
+            location=empty_batch_prediction_job.location,
+        )
+
+        model_resource_name = (
+            model_or_model_name
+            if isinstance(model_or_model_name, str)
+            else model_or_model_name.resource_name
+        )
+
+        gca_batch_prediction_job.model = model_resource_name
+
+        api_client = empty_batch_prediction_job.api_client
+
         if generate_explanation:
             api_client = api_client.select_version(compat.V1BETA1)
 
         _LOGGER.log_create_with_lro(cls)
 
         gca_batch_prediction_job = api_client.create_batch_prediction_job(
-            parent=parent, batch_prediction_job=batch_prediction_job
+            parent=parent, batch_prediction_job=gca_batch_prediction_job
         )
 
-        batch_prediction_job = cls(
-            batch_prediction_job_name=gca_batch_prediction_job.name,
-            project=project,
-            location=location,
-            credentials=credentials,
-        )
+        empty_batch_prediction_job._gca_resource = gca_batch_prediction_job
+
+        batch_prediction_job = empty_batch_prediction_job
 
         _LOGGER.log_create_complete(cls, batch_prediction_job._gca_resource, "bpj")
 
@@ -843,6 +838,10 @@ def iter_outputs(
             f"on your prediction output:\n {output_info}"
         )
 
+    def wait_for_resource_creation(self) -> None:
+        """Waits until resource has been created."""
+        self._wait_for_resource_creation()
+
 
 class _RunnableJob(_Job):
     """ABC to interface job as a runnable training class."""