Skip to content

Commit c76ac62

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: GenAI - Add support for self-hosted OSS models in Batch Prediction.
PiperOrigin-RevId: 752487510
1 parent b1bbba6 commit c76ac62

File tree

2 files changed

+75
-10
lines changed

2 files changed

+75
-10
lines changed

tests/unit/vertexai/test_batch_prediction.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@
5656
_TEST_CLAUDE_MODEL_RESOURCE_NAME = (
5757
f"publishers/anthropic/models/{_TEST_CLAUDE_MODEL_NAME}"
5858
)
59+
_TEST_SELF_HOSTED_GEMMA_MODEL_RESOURCE_NAME = (
60+
"publishers/google/models/gemma@gemma-2b-it"
61+
)
5962

6063
_TEST_GCS_INPUT_URI = "gs://test-bucket/test-input.jsonl"
6164
_TEST_GCS_INPUT_URI_2 = "gs://test-bucket/test-input-2.jsonl"
@@ -589,6 +592,39 @@ def test_submit_batch_prediction_job_with_tuned_model(
589592
retry=aiplatform_base._DEFAULT_RETRY,
590593
)
591594

595+
def test_submit_batch_prediction_job_with_self_hosted_gemma_model(
596+
self,
597+
create_batch_prediction_job_mock,
598+
):
599+
job = batch_prediction.BatchPredictionJob.submit(
600+
source_model=_TEST_SELF_HOSTED_GEMMA_MODEL_RESOURCE_NAME,
601+
input_dataset=_TEST_BQ_INPUT_URI,
602+
)
603+
604+
assert job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB
605+
606+
expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat.BatchPredictionJob(
607+
display_name=_TEST_DISPLAY_NAME,
608+
model=_TEST_SELF_HOSTED_GEMMA_MODEL_RESOURCE_NAME,
609+
input_config=gca_batch_prediction_job_compat.BatchPredictionJob.InputConfig(
610+
instances_format="bigquery",
611+
bigquery_source=gca_io_compat.BigQuerySource(
612+
input_uri=_TEST_BQ_INPUT_URI
613+
),
614+
),
615+
output_config=gca_batch_prediction_job_compat.BatchPredictionJob.OutputConfig(
616+
bigquery_destination=gca_io_compat.BigQueryDestination(
617+
output_uri=_TEST_BQ_OUTPUT_PREFIX
618+
),
619+
predictions_format="bigquery",
620+
),
621+
)
622+
create_batch_prediction_job_mock.assert_called_once_with(
623+
parent=_TEST_PARENT,
624+
batch_prediction_job=expected_gapic_batch_prediction_job,
625+
timeout=None,
626+
)
627+
592628
def test_submit_batch_prediction_job_with_invalid_source_model(self):
593629
with pytest.raises(
594630
ValueError,

vertexai/batch_prediction/_batch_prediction.py

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,11 @@ def submit(
116116
*,
117117
output_uri_prefix: Optional[str] = None,
118118
job_display_name: Optional[str] = None,
119+
machine_type: Optional[str] = None,
120+
accelerator_type: Optional[str] = None,
121+
accelerator_count: Optional[int] = None,
122+
starting_replica_count: Optional[int] = None,
123+
max_replica_count: Optional[int] = None,
119124
) -> "BatchPredictionJob":
120125
"""Submits a batch prediction job for a GenAI model.
121126
@@ -142,6 +147,16 @@ def submit(
142147
The user-defined name of the BatchPredictionJob.
143148
The name can be up to 128 characters long and can consist
144149
of any UTF-8 characters.
150+
machine_type (str):
151+
The type of machine for running the batch prediction job.
152+
accelerator_type (str):
153+
The type of accelerator for running the batch prediction job.
154+
accelerator_count (int):
155+
The number of accelerators for running the batch prediction job.
156+
starting_replica_count (int):
157+
The starting number of replicas for running the batch prediction job.
158+
max_replica_count (int):
159+
The maximum number of replicas for running the batch prediction job.
145160
146161
Returns:
147162
Instantiated BatchPredictionJob.
@@ -219,6 +234,11 @@ def submit(
219234
bigquery_source=bigquery_source,
220235
gcs_destination_prefix=gcs_destination_prefix,
221236
bigquery_destination_prefix=bigquery_destination_prefix,
237+
machine_type=machine_type,
238+
accelerator_type=accelerator_type,
239+
accelerator_count=accelerator_count,
240+
starting_replica_count=starting_replica_count,
241+
max_replica_count=max_replica_count,
222242
)
223243
job = cls._empty_constructor()
224244
job._gca_resource = aiplatform_job._gca_resource
@@ -281,27 +301,29 @@ def _reconcile_model_name(cls, model_name: str) -> str:
281301
if "/" not in model_name:
282302
# model name (e.g., gemini-1.0-pro)
283303
if model_name.startswith("gemini"):
284-
model_name = "publishers/google/models/" + model_name
304+
return "publishers/google/models/" + model_name
285305
else:
286306
raise ValueError(
287307
"Abbreviated model names are only supported for Gemini models. "
288308
"Please provide the full publisher model name."
289309
)
290310
elif model_name.startswith("models/"):
291311
# publisher model name (e.g., models/gemini-1.0-pro)
292-
model_name = "publishers/google/" + model_name
312+
return "publishers/google/" + model_name
293313
elif (
294-
# publisher model full name
295-
not model_name.startswith("publishers/google/models/")
296-
and not model_name.startswith("publishers/meta/models/")
297-
and not model_name.startswith("publishers/anthropic/models/")
298-
# tuned model full resource name
299-
and not re.search(_GEMINI_TUNED_MODEL_PATTERN, model_name)
314+
re.match(
315+
r"^publishers/(?P<publisher>[^/]+)/models/(?P<model>[^@]+)@(?P<version>[^@]+)$",
316+
model_name,
317+
)
318+
or model_name.startswith("publishers/google/models/")
319+
or model_name.startswith("publishers/meta/models/")
320+
or model_name.startswith("publishers/anthropic/models/")
321+
or re.search(_GEMINI_TUNED_MODEL_PATTERN, model_name)
300322
):
323+
return model_name
324+
else:
301325
raise ValueError(f"Invalid format for model name: {model_name}.")
302326

303-
return model_name
304-
305327
@classmethod
306328
def _is_genai_model(cls, model_name: str) -> bool:
307329
"""Validates if a given model_name represents a GenAI model."""
@@ -326,6 +348,13 @@ def _is_genai_model(cls, model_name: str) -> bool:
326348
# Model is a claude model.
327349
return True
328350

351+
if re.match(
352+
r"^publishers/(?P<publisher>[^/]+)/models/(?P<model>[^@]+)@(?P<version>[^@]+)$",
353+
model_name,
354+
):
355+
# Model is a self-hosted model.
356+
return True
357+
329358
return False
330359

331360
@classmethod

0 commit comments

Comments
 (0)