Skip to content

Commit 4d091c6

Browse files
jaycee-li and copybara-github
authored and committed
feat: GenAI - Added the BatchPredictionJob.submit method
PiperOrigin-RevId: 633686787
1 parent df4a4f2 commit 4d091c6

File tree

2 files changed

+427
-11
lines changed

2 files changed

+427
-11
lines changed

tests/unit/vertexai/test_batch_prediction.py

Lines changed: 254 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,46 +17,83 @@
1717
"""Unit tests for generative model batch prediction."""
1818
# pylint: disable=protected-access
1919

20+
import importlib
2021
import pytest
2122
from unittest import mock
2223

24+
from google.cloud import aiplatform
2325
import vertexai
2426
from google.cloud.aiplatform import base as aiplatform_base
2527
from google.cloud.aiplatform import initializer as aiplatform_initializer
2628
from google.cloud.aiplatform.compat.services import job_service_client
2729
from google.cloud.aiplatform.compat.types import (
2830
batch_prediction_job as gca_batch_prediction_job_compat,
29-
)
30-
from google.cloud.aiplatform.compat.types import (
31+
io as gca_io_compat,
3132
job_state as gca_job_state_compat,
3233
)
3334
from vertexai.preview import batch_prediction
35+
from vertexai.generative_models import GenerativeModel
3436

3537

3638
# Project-level values shared by every test in this module.
_TEST_PROJECT = "test-project"
_TEST_LOCATION = "us-central1"
_TEST_BUCKET = "gs://test-bucket"
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
_TEST_DISPLAY_NAME = "test-display-name"

# Publisher model short names and the resource names derived from them.
# The PaLM model is used by the tests below as the rejected (non-GenAI) case.
_TEST_GEMINI_MODEL_NAME = "gemini-1.0-pro"
_TEST_GEMINI_MODEL_RESOURCE_NAME = f"publishers/google/models/{_TEST_GEMINI_MODEL_NAME}"
_TEST_PALM_MODEL_NAME = "text-bison"
_TEST_PALM_MODEL_RESOURCE_NAME = f"publishers/google/models/{_TEST_PALM_MODEL_NAME}"

# Input datasets and output prefixes (GCS and BigQuery), plus a URI that uses
# neither scheme for the validation tests.
_TEST_GCS_INPUT_URI = "gs://test-bucket/test-input.jsonl"
_TEST_GCS_INPUT_URI_2 = "gs://test-bucket/test-input-2.jsonl"
_TEST_GCS_OUTPUT_PREFIX = "gs://test-bucket/test-output"
_TEST_BQ_INPUT_URI = "bq://test-project.test-dataset.test-input"
_TEST_BQ_OUTPUT_PREFIX = "bq://test-project.test-dataset.test-output"
_TEST_INVALID_URI = "invalid-uri"

_TEST_BATCH_PREDICTION_JOB_ID = "123456789"
_TEST_BATCH_PREDICTION_JOB_NAME = (
    f"{_TEST_PARENT}/batchPredictionJobs/{_TEST_BATCH_PREDICTION_JOB_ID}"
)
# Named enum members instead of the raw integers 3 and 4, so the intended job
# state is readable at the point of use.
_TEST_JOB_STATE_RUNNING = gca_job_state_compat.JobState.JOB_STATE_RUNNING
_TEST_JOB_STATE_SUCCESS = gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED

# Canned job resource returned by the mocked get/create service calls.
_TEST_GAPIC_BATCH_PREDICTION_JOB = gca_batch_prediction_job_compat.BatchPredictionJob(
    name=_TEST_BATCH_PREDICTION_JOB_NAME,
    display_name=_TEST_DISPLAY_NAME,
    model=_TEST_GEMINI_MODEL_RESOURCE_NAME,
    state=_TEST_JOB_STATE_RUNNING,
)
70+
4871

4972
# TODO(b/339230025) Mock the whole service instead of methods.
73+
@pytest.fixture
def generate_display_name_mock():
    """Patch ``VertexAiResourceNoun._generate_display_name`` to return ``_TEST_DISPLAY_NAME``.

    Yields:
        The mock that replaces ``_generate_display_name`` while the fixture
        is active.
    """
    display_name_patch = mock.patch.object(
        aiplatform_base.VertexAiResourceNoun,
        "_generate_display_name",
        return_value=_TEST_DISPLAY_NAME,
    )
    with display_name_patch as display_name_stub:
        yield display_name_stub
80+
81+
82+
@pytest.fixture
def complete_bq_uri_mock():
    """Patch ``BatchPredictionJob._complete_bq_uri`` to return ``_TEST_BQ_OUTPUT_PREFIX``.

    Yields:
        The mock that replaces ``_complete_bq_uri`` while the fixture is
        active.
    """
    bq_uri_patch = mock.patch.object(
        batch_prediction.BatchPredictionJob,
        "_complete_bq_uri",
        return_value=_TEST_BQ_OUTPUT_PREFIX,
    )
    with bq_uri_patch as bq_uri_stub:
        yield bq_uri_stub
89+
90+
5091
@pytest.fixture
def get_batch_prediction_job_mock():
    """Patch ``JobServiceClient.get_batch_prediction_job`` to return the canned job.

    Yields:
        The mock that replaces ``get_batch_prediction_job``; it returns
        ``_TEST_GAPIC_BATCH_PREDICTION_JOB``.
    """
    get_job_patch = mock.patch.object(
        job_service_client.JobServiceClient,
        "get_batch_prediction_job",
        return_value=_TEST_GAPIC_BATCH_PREDICTION_JOB,
    )
    with get_job_patch as get_job_stub:
        yield get_job_stub
6198

6299

@@ -67,17 +104,32 @@ def get_batch_prediction_job_invalid_model_mock():
67104
) as get_job_mock:
68105
get_job_mock.return_value = gca_batch_prediction_job_compat.BatchPredictionJob(
69106
name=_TEST_BATCH_PREDICTION_JOB_NAME,
107+
display_name=_TEST_DISPLAY_NAME,
70108
model=_TEST_PALM_MODEL_RESOURCE_NAME,
71109
state=_TEST_JOB_STATE_SUCCESS,
72110
)
73111
yield get_job_mock
74112

75113

76-
@pytest.mark.usefixtures("google_auth_mock")
114+
@pytest.fixture
def create_batch_prediction_job_mock():
    """Patch ``JobServiceClient.create_batch_prediction_job`` to return the canned job.

    Yields:
        The mock that replaces ``create_batch_prediction_job``; it returns
        ``_TEST_GAPIC_BATCH_PREDICTION_JOB``, and tests inspect its call
        arguments.
    """
    create_job_patch = mock.patch.object(
        job_service_client.JobServiceClient,
        "create_batch_prediction_job",
        return_value=_TEST_GAPIC_BATCH_PREDICTION_JOB,
    )
    with create_job_patch as create_job_stub:
        yield create_job_stub
121+
122+
123+
@pytest.mark.usefixtures(
124+
"google_auth_mock", "generate_display_name_mock", "complete_bq_uri_mock"
125+
)
77126
class TestBatchPredictionJob:
78127
"""Unit tests for BatchPredictionJob."""
79128

80129
def setup_method(self):
130+
importlib.reload(aiplatform_initializer)
131+
importlib.reload(aiplatform)
132+
importlib.reload(vertexai)
81133
vertexai.init(
82134
project=_TEST_PROJECT,
83135
location=_TEST_LOCATION,
@@ -104,3 +156,196 @@ def test_init_batch_prediction_job_invalid_model(self):
104156
),
105157
):
106158
batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID)
159+
160+
def test_submit_batch_prediction_job_with_gcs_input(
    self, create_batch_prediction_job_mock
):
    """Submit with one GCS input URI and an explicit GCS output prefix."""
    submitted_job = batch_prediction.BatchPredictionJob.submit(
        source_model=_TEST_GEMINI_MODEL_NAME,
        input_dataset=_TEST_GCS_INPUT_URI,
        output_uri_prefix=_TEST_GCS_OUTPUT_PREFIX,
    )

    assert submitted_job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB

    # Build the expected request piece by piece: input, output, then the job.
    job_proto = gca_batch_prediction_job_compat.BatchPredictionJob
    expected_input_config = job_proto.InputConfig(
        instances_format="jsonl",
        gcs_source=gca_io_compat.GcsSource(uris=[_TEST_GCS_INPUT_URI]),
    )
    expected_output_config = job_proto.OutputConfig(
        gcs_destination=gca_io_compat.GcsDestination(
            output_uri_prefix=_TEST_GCS_OUTPUT_PREFIX
        ),
        predictions_format="jsonl",
    )
    expected_request_job = job_proto(
        display_name=_TEST_DISPLAY_NAME,
        model=_TEST_GEMINI_MODEL_RESOURCE_NAME,
        input_config=expected_input_config,
        output_config=expected_output_config,
    )
    create_batch_prediction_job_mock.assert_called_once_with(
        parent=_TEST_PARENT,
        batch_prediction_job=expected_request_job,
        timeout=None,
    )
190+
191+
def test_submit_batch_prediction_job_with_bq_input(
    self, create_batch_prediction_job_mock
):
    """Submit with a BigQuery input table and an explicit BigQuery output prefix."""
    submitted_job = batch_prediction.BatchPredictionJob.submit(
        source_model=_TEST_GEMINI_MODEL_NAME,
        input_dataset=_TEST_BQ_INPUT_URI,
        output_uri_prefix=_TEST_BQ_OUTPUT_PREFIX,
    )

    assert submitted_job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB

    # Build the expected request piece by piece: input, output, then the job.
    job_proto = gca_batch_prediction_job_compat.BatchPredictionJob
    expected_input_config = job_proto.InputConfig(
        instances_format="bigquery",
        bigquery_source=gca_io_compat.BigQuerySource(input_uri=_TEST_BQ_INPUT_URI),
    )
    expected_output_config = job_proto.OutputConfig(
        bigquery_destination=gca_io_compat.BigQueryDestination(
            output_uri=_TEST_BQ_OUTPUT_PREFIX
        ),
        predictions_format="bigquery",
    )
    expected_request_job = job_proto(
        display_name=_TEST_DISPLAY_NAME,
        model=_TEST_GEMINI_MODEL_RESOURCE_NAME,
        input_config=expected_input_config,
        output_config=expected_output_config,
    )
    create_batch_prediction_job_mock.assert_called_once_with(
        parent=_TEST_PARENT,
        batch_prediction_job=expected_request_job,
        timeout=None,
    )
223+
224+
def test_submit_batch_prediction_job_with_gcs_input_without_output_uri_prefix(
    self, create_batch_prediction_job_mock
):
    """Submit multiple GCS inputs with no output prefix, using the staging bucket.

    The expected output prefix is ``gen-ai-batch-prediction`` under the
    staging bucket passed to ``vertexai.init``.
    """
    vertexai.init(staging_bucket=_TEST_BUCKET)
    gemini_model = GenerativeModel(_TEST_GEMINI_MODEL_NAME)
    input_uris = [_TEST_GCS_INPUT_URI, _TEST_GCS_INPUT_URI_2]
    submitted_job = batch_prediction.BatchPredictionJob.submit(
        source_model=gemini_model,
        input_dataset=input_uris,
    )

    assert submitted_job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB

    # Build the expected request piece by piece: input, output, then the job.
    job_proto = gca_batch_prediction_job_compat.BatchPredictionJob
    expected_input_config = job_proto.InputConfig(
        instances_format="jsonl",
        gcs_source=gca_io_compat.GcsSource(
            uris=[_TEST_GCS_INPUT_URI, _TEST_GCS_INPUT_URI_2]
        ),
    )
    expected_output_config = job_proto.OutputConfig(
        gcs_destination=gca_io_compat.GcsDestination(
            output_uri_prefix=f"{_TEST_BUCKET}/gen-ai-batch-prediction"
        ),
        predictions_format="jsonl",
    )
    expected_request_job = job_proto(
        display_name=_TEST_DISPLAY_NAME,
        model=_TEST_GEMINI_MODEL_RESOURCE_NAME,
        input_config=expected_input_config,
        output_config=expected_output_config,
    )
    create_batch_prediction_job_mock.assert_called_once_with(
        parent=_TEST_PARENT,
        batch_prediction_job=expected_request_job,
        timeout=None,
    )
257+
258+
def test_submit_batch_prediction_job_with_bq_input_without_output_uri_prefix(
    self, create_batch_prediction_job_mock
):
    """Submit a BigQuery input with no output prefix.

    The expected BigQuery destination is ``_TEST_BQ_OUTPUT_PREFIX``, the
    value the class-level ``complete_bq_uri_mock`` fixture returns.
    """
    gemini_model = GenerativeModel(_TEST_GEMINI_MODEL_NAME)
    submitted_job = batch_prediction.BatchPredictionJob.submit(
        source_model=gemini_model,
        input_dataset=_TEST_BQ_INPUT_URI,
    )

    assert submitted_job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB

    # Build the expected request piece by piece: input, output, then the job.
    job_proto = gca_batch_prediction_job_compat.BatchPredictionJob
    expected_input_config = job_proto.InputConfig(
        instances_format="bigquery",
        bigquery_source=gca_io_compat.BigQuerySource(input_uri=_TEST_BQ_INPUT_URI),
    )
    expected_output_config = job_proto.OutputConfig(
        bigquery_destination=gca_io_compat.BigQueryDestination(
            output_uri=_TEST_BQ_OUTPUT_PREFIX
        ),
        predictions_format="bigquery",
    )
    expected_request_job = job_proto(
        display_name=_TEST_DISPLAY_NAME,
        model=_TEST_GEMINI_MODEL_RESOURCE_NAME,
        input_config=expected_input_config,
        output_config=expected_output_config,
    )
    create_batch_prediction_job_mock.assert_called_once_with(
        parent=_TEST_PARENT,
        batch_prediction_job=expected_request_job,
        timeout=None,
    )
290+
291+
def test_submit_batch_prediction_job_with_invalid_source_model(self):
    """Expect ``ValueError`` when ``submit`` is given the PaLM test model.

    ``pytest.raises(match=...)`` interprets its argument as a regular
    expression, so the expected message is passed through ``re.escape`` to
    compare it literally (the message contains ``.``).
    """
    import re  # Function-scope import keeps this test self-contained.

    expected_message = (
        f"Model '{_TEST_PALM_MODEL_RESOURCE_NAME}' is not a GenAI model."
    )
    with pytest.raises(ValueError, match=re.escape(expected_message)):
        batch_prediction.BatchPredictionJob.submit(
            source_model=_TEST_PALM_MODEL_NAME,
            input_dataset=_TEST_GCS_INPUT_URI,
        )
300+
301+
def test_submit_batch_prediction_job_with_invalid_input_dataset(self):
    """Expect ``ValueError`` for unsupported input datasets.

    Two cases are checked: a URI that is neither ``gs://`` nor ``bq://``,
    and a list of more than one BigQuery URI.

    ``pytest.raises(match=...)`` interprets its argument as a regular
    expression, so each expected message is passed through ``re.escape``
    to compare it literally (the messages contain ``.``).
    """
    import re  # Function-scope import keeps this test self-contained.

    unsupported_uri_message = (
        f"Unsupported input URI: {_TEST_INVALID_URI}. "
        "Supported formats: 'gs://path/to/input/data.jsonl' and "
        "'bq://projectId.bqDatasetId.bqTableId'"
    )
    with pytest.raises(ValueError, match=re.escape(unsupported_uri_message)):
        batch_prediction.BatchPredictionJob.submit(
            source_model=_TEST_GEMINI_MODEL_NAME,
            input_dataset=_TEST_INVALID_URI,
        )

    invalid_bq_uris = ["bq://projectId.dataset1", "bq://projectId.dataset2"]
    multiple_bq_message = "Multiple Bigquery input datasets are not supported."
    with pytest.raises(ValueError, match=re.escape(multiple_bq_message)):
        batch_prediction.BatchPredictionJob.submit(
            source_model=_TEST_GEMINI_MODEL_NAME,
            input_dataset=invalid_bq_uris,
        )
324+
325+
def test_submit_batch_prediction_job_with_invalid_output_uri_prefix(self):
    """Expect ``ValueError`` when the output prefix is neither ``gs://`` nor ``bq://``.

    ``pytest.raises(match=...)`` interprets its argument as a regular
    expression, so the expected message is passed through ``re.escape`` to
    compare it literally (the message contains ``.``).
    """
    import re  # Function-scope import keeps this test self-contained.

    expected_message = (
        f"Unsupported output URI: {_TEST_INVALID_URI}. "
        "Supported formats: 'gs://path/to/output/data' and "
        "'bq://projectId.bqDatasetId'"
    )
    with pytest.raises(ValueError, match=re.escape(expected_message)):
        batch_prediction.BatchPredictionJob.submit(
            source_model=_TEST_GEMINI_MODEL_NAME,
            input_dataset=_TEST_GCS_INPUT_URI,
            output_uri_prefix=_TEST_INVALID_URI,
        )
339+
340+
def test_submit_batch_prediction_job_without_output_uri_prefix_and_bucket(self):
    """Expect ``ValueError`` for a GCS input with no ``output_uri_prefix``.

    This test does not pass ``staging_bucket`` to ``vertexai.init`` itself.

    ``pytest.raises(match=...)`` interprets its argument as a regular
    expression. Unescaped, ``init().`` is an empty group followed by a
    wildcard, so the literal parentheses were never checked; ``re.escape``
    makes the comparison literal.
    """
    import re  # Function-scope import keeps this test self-contained.

    expected_message = (
        "Please either specify output_uri_prefix or "
        "set staging_bucket in vertexai.init()."
    )
    with pytest.raises(ValueError, match=re.escape(expected_message)):
        batch_prediction.BatchPredictionJob.submit(
            source_model=_TEST_GEMINI_MODEL_NAME,
            input_dataset=_TEST_GCS_INPUT_URI,
        )

0 commit comments

Comments
 (0)