@@ -116,6 +116,11 @@ def submit(
116
116
* ,
117
117
output_uri_prefix : Optional [str ] = None ,
118
118
job_display_name : Optional [str ] = None ,
119
+ machine_type : Optional [str ] = None ,
120
+ accelerator_type : Optional [str ] = None ,
121
+ accelerator_count : Optional [int ] = None ,
122
+ starting_replica_count : Optional [int ] = None ,
123
+ max_replica_count : Optional [int ] = None ,
119
124
) -> "BatchPredictionJob" :
120
125
"""Submits a batch prediction job for a GenAI model.
121
126
@@ -142,6 +147,16 @@ def submit(
142
147
The user-defined name of the BatchPredictionJob.
143
148
The name can be up to 128 characters long and can be consist
144
149
of any UTF-8 characters.
150
+ machine_type (str):
151
+ The type of machine for running batch prediction job.
152
+ accelerator_type (str):
153
+ The type of accelerator for running batch prediction job.
154
+ accelerator_count (int):
155
+ The number of accelerators for running batch prediction job.
156
+ starting_replica_count (int):
157
+ The starting number of replica for running batch prediction job.
158
+ max_replica_count (int):
159
+ The maximum number of replica for running batch prediction job.
145
160
146
161
Returns:
147
162
Instantiated BatchPredictionJob.
@@ -219,6 +234,11 @@ def submit(
219
234
bigquery_source = bigquery_source ,
220
235
gcs_destination_prefix = gcs_destination_prefix ,
221
236
bigquery_destination_prefix = bigquery_destination_prefix ,
237
+ machine_type = machine_type ,
238
+ accelerator_type = accelerator_type ,
239
+ accelerator_count = accelerator_count ,
240
+ starting_replica_count = starting_replica_count ,
241
+ max_replica_count = max_replica_count ,
222
242
)
223
243
job = cls ._empty_constructor ()
224
244
job ._gca_resource = aiplatform_job ._gca_resource
@@ -281,27 +301,29 @@ def _reconcile_model_name(cls, model_name: str) -> str:
281
301
if "/" not in model_name :
282
302
# model name (e.g., gemini-1.0-pro)
283
303
if model_name .startswith ("gemini" ):
284
- model_name = "publishers/google/models/" + model_name
304
+ return "publishers/google/models/" + model_name
285
305
else :
286
306
raise ValueError (
287
307
"Abbreviated model names are only supported for Gemini models. "
288
308
"Please provide the full publisher model name."
289
309
)
290
310
elif model_name .startswith ("models/" ):
291
311
# publisher model name (e.g., models/gemini-1.0-pro)
292
- model_name = "publishers/google/" + model_name
312
+ return "publishers/google/" + model_name
293
313
elif (
294
- # publisher model full name
295
- not model_name .startswith ("publishers/google/models/" )
296
- and not model_name .startswith ("publishers/meta/models/" )
297
- and not model_name .startswith ("publishers/anthropic/models/" )
298
- # tuned model full resource name
299
- and not re .search (_GEMINI_TUNED_MODEL_PATTERN , model_name )
314
+ re .match (
315
+ r"^publishers/(?P<publisher>[^/]+)/models/(?P<model>[^@]+)@(?P<version>[^@]+)$" ,
316
+ model_name ,
317
+ )
318
+ or model_name .startswith ("publishers/google/models/" )
319
+ or model_name .startswith ("publishers/meta/models/" )
320
+ or model_name .startswith ("publishers/anthropic/models/" )
321
+ or re .search (_GEMINI_TUNED_MODEL_PATTERN , model_name )
300
322
):
323
+ return model_name
324
+ else :
301
325
raise ValueError (f"Invalid format for model name: { model_name } ." )
302
326
303
- return model_name
304
-
305
327
@classmethod
306
328
def _is_genai_model (cls , model_name : str ) -> bool :
307
329
"""Validates if a given model_name represents a GenAI model."""
@@ -326,6 +348,13 @@ def _is_genai_model(cls, model_name: str) -> bool:
326
348
# Model is a claude model.
327
349
return True
328
350
351
+ if re .match (
352
+ r"^publishers/(?P<publisher>[^/]+)/models/(?P<model>[^@]+)@(?P<version>[^@]+)$" ,
353
+ model_name ,
354
+ ):
355
+ # Model is a self-hosted model.
356
+ return True
357
+
329
358
return False
330
359
331
360
@classmethod
0 commit comments