17
17
"""Unit tests for generative model batch prediction."""
18
18
# pylint: disable=protected-access
19
19
20
import importlib
import re
from unittest import mock

import pytest

from google.cloud import aiplatform
import vertexai
from google.cloud.aiplatform import base as aiplatform_base
from google.cloud.aiplatform import initializer as aiplatform_initializer
from google.cloud.aiplatform.compat.services import job_service_client
from google.cloud.aiplatform.compat.types import (
    batch_prediction_job as gca_batch_prediction_job_compat,
    io as gca_io_compat,
    job_state as gca_job_state_compat,
)
from vertexai.preview import batch_prediction
from vertexai.generative_models import GenerativeModel
34
36
35
37
36
38
# Project-level identifiers used to initialise the SDK and to build resource names.
_TEST_PROJECT = "test-project"
_TEST_LOCATION = "us-central1"
_TEST_BUCKET = "gs://test-bucket"
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
_TEST_DISPLAY_NAME = "test-display-name"

# Short model names and the publisher resource names derived from them.
_TEST_GEMINI_MODEL_NAME = "gemini-1.0-pro"
_TEST_GEMINI_MODEL_RESOURCE_NAME = f"publishers/google/models/{_TEST_GEMINI_MODEL_NAME}"
_TEST_PALM_MODEL_NAME = "text-bison"
_TEST_PALM_MODEL_RESOURCE_NAME = f"publishers/google/models/{_TEST_PALM_MODEL_NAME}"

# Input/output locations passed to BatchPredictionJob.submit() by the tests.
_TEST_GCS_INPUT_URI = "gs://test-bucket/test-input.jsonl"
_TEST_GCS_INPUT_URI_2 = "gs://test-bucket/test-input-2.jsonl"
_TEST_GCS_OUTPUT_PREFIX = "gs://test-bucket/test-output"
_TEST_BQ_INPUT_URI = "bq://test-project.test-dataset.test-input"
_TEST_BQ_OUTPUT_PREFIX = "bq://test-project.test-dataset.test-output"
_TEST_INVALID_URI = "invalid-uri"


_TEST_BATCH_PREDICTION_JOB_ID = "123456789"
# Built from _TEST_PARENT so the job name always agrees with the parent
# asserted in the create_batch_prediction_job calls.
_TEST_BATCH_PREDICTION_JOB_NAME = (
    f"{_TEST_PARENT}/batchPredictionJobs/{_TEST_BATCH_PREDICTION_JOB_ID}"
)
_TEST_JOB_STATE_RUNNING = gca_job_state_compat.JobState(3)
_TEST_JOB_STATE_SUCCESS = gca_job_state_compat.JobState(4)

# Canned job resource returned by the mocked get/create service calls.
# NOTE(review): this single proto instance is shared by several fixtures;
# tests should treat it as read-only.
_TEST_GAPIC_BATCH_PREDICTION_JOB = gca_batch_prediction_job_compat.BatchPredictionJob(
    name=_TEST_BATCH_PREDICTION_JOB_NAME,
    display_name=_TEST_DISPLAY_NAME,
    model=_TEST_GEMINI_MODEL_RESOURCE_NAME,
    state=_TEST_JOB_STATE_RUNNING,
)
70
+
48
71
49
72
# TODO(b/339230025) Mock the whole service instead of methods.
73
@pytest.fixture
def generate_display_name_mock():
    """Patch ``VertexAiResourceNoun._generate_display_name`` for one test.

    Yields:
        The mock standing in for ``_generate_display_name``; it returns
        ``_TEST_DISPLAY_NAME``.
    """
    patcher = mock.patch.object(
        aiplatform_base.VertexAiResourceNoun,
        "_generate_display_name",
        return_value=_TEST_DISPLAY_NAME,
    )
    with patcher as display_name_mock:
        yield display_name_mock
80
+
81
+
82
@pytest.fixture
def complete_bq_uri_mock():
    """Patch ``BatchPredictionJob._complete_bq_uri`` for one test.

    Yields:
        The mock standing in for ``_complete_bq_uri``; it returns
        ``_TEST_BQ_OUTPUT_PREFIX``.
    """
    patcher = mock.patch.object(
        batch_prediction.BatchPredictionJob,
        "_complete_bq_uri",
        return_value=_TEST_BQ_OUTPUT_PREFIX,
    )
    with patcher as bq_uri_mock:
        yield bq_uri_mock
89
+
90
+
50
91
@pytest.fixture
def get_batch_prediction_job_mock():
    """Patch ``JobServiceClient.get_batch_prediction_job`` for one test.

    The mocked call returns the shared ``_TEST_GAPIC_BATCH_PREDICTION_JOB``
    resource. The earlier inline job resource was overwritten before use and
    has been removed.

    Yields:
        The mock standing in for ``get_batch_prediction_job``.
    """
    with mock.patch.object(
        job_service_client.JobServiceClient, "get_batch_prediction_job"
    ) as get_job_mock:
        get_job_mock.return_value = _TEST_GAPIC_BATCH_PREDICTION_JOB
        yield get_job_mock
61
98
62
99
@@ -67,17 +104,32 @@ def get_batch_prediction_job_invalid_model_mock():
67
104
) as get_job_mock :
68
105
get_job_mock .return_value = gca_batch_prediction_job_compat .BatchPredictionJob (
69
106
name = _TEST_BATCH_PREDICTION_JOB_NAME ,
107
+ display_name = _TEST_DISPLAY_NAME ,
70
108
model = _TEST_PALM_MODEL_RESOURCE_NAME ,
71
109
state = _TEST_JOB_STATE_SUCCESS ,
72
110
)
73
111
yield get_job_mock
74
112
75
113
76
- @pytest .mark .usefixtures ("google_auth_mock" )
114
@pytest.fixture
def create_batch_prediction_job_mock():
    """Patch ``JobServiceClient.create_batch_prediction_job`` for one test.

    Yields:
        The mock standing in for ``create_batch_prediction_job``; it returns
        ``_TEST_GAPIC_BATCH_PREDICTION_JOB``.
    """
    patcher = mock.patch.object(
        job_service_client.JobServiceClient,
        "create_batch_prediction_job",
        return_value=_TEST_GAPIC_BATCH_PREDICTION_JOB,
    )
    with patcher as create_mock:
        yield create_mock
121
+
122
+
123
+ @pytest .mark .usefixtures (
124
+ "google_auth_mock" , "generate_display_name_mock" , "complete_bq_uri_mock"
125
+ )
77
126
class TestBatchPredictionJob :
78
127
"""Unit tests for BatchPredictionJob."""
79
128
80
129
def setup_method (self ):
130
+ importlib .reload (aiplatform_initializer )
131
+ importlib .reload (aiplatform )
132
+ importlib .reload (vertexai )
81
133
vertexai .init (
82
134
project = _TEST_PROJECT ,
83
135
location = _TEST_LOCATION ,
@@ -104,3 +156,196 @@ def test_init_batch_prediction_job_invalid_model(self):
104
156
),
105
157
):
106
158
batch_prediction .BatchPredictionJob (_TEST_BATCH_PREDICTION_JOB_ID )
159
+
160
def test_submit_batch_prediction_job_with_gcs_input(
    self, create_batch_prediction_job_mock
):
    """Submit with one GCS input URI and an explicit GCS output prefix."""
    submitted_job = batch_prediction.BatchPredictionJob.submit(
        source_model=_TEST_GEMINI_MODEL_NAME,
        input_dataset=_TEST_GCS_INPUT_URI,
        output_uri_prefix=_TEST_GCS_OUTPUT_PREFIX,
    )
    assert submitted_job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB

    # Assemble the expected request from its input and output halves.
    job_type = gca_batch_prediction_job_compat.BatchPredictionJob
    expected_input = job_type.InputConfig(
        instances_format="jsonl",
        gcs_source=gca_io_compat.GcsSource(uris=[_TEST_GCS_INPUT_URI]),
    )
    expected_output = job_type.OutputConfig(
        gcs_destination=gca_io_compat.GcsDestination(
            output_uri_prefix=_TEST_GCS_OUTPUT_PREFIX
        ),
        predictions_format="jsonl",
    )
    expected_request = job_type(
        display_name=_TEST_DISPLAY_NAME,
        model=_TEST_GEMINI_MODEL_RESOURCE_NAME,
        input_config=expected_input,
        output_config=expected_output,
    )

    create_batch_prediction_job_mock.assert_called_once_with(
        parent=_TEST_PARENT,
        batch_prediction_job=expected_request,
        timeout=None,
    )
190
+
191
def test_submit_batch_prediction_job_with_bq_input(
    self, create_batch_prediction_job_mock
):
    """Submit with a BigQuery input table and an explicit BigQuery output URI."""
    submitted_job = batch_prediction.BatchPredictionJob.submit(
        source_model=_TEST_GEMINI_MODEL_NAME,
        input_dataset=_TEST_BQ_INPUT_URI,
        output_uri_prefix=_TEST_BQ_OUTPUT_PREFIX,
    )
    assert submitted_job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB

    # Assemble the expected request from its input and output halves.
    job_type = gca_batch_prediction_job_compat.BatchPredictionJob
    expected_input = job_type.InputConfig(
        instances_format="bigquery",
        bigquery_source=gca_io_compat.BigQuerySource(input_uri=_TEST_BQ_INPUT_URI),
    )
    expected_output = job_type.OutputConfig(
        bigquery_destination=gca_io_compat.BigQueryDestination(
            output_uri=_TEST_BQ_OUTPUT_PREFIX
        ),
        predictions_format="bigquery",
    )
    expected_request = job_type(
        display_name=_TEST_DISPLAY_NAME,
        model=_TEST_GEMINI_MODEL_RESOURCE_NAME,
        input_config=expected_input,
        output_config=expected_output,
    )

    create_batch_prediction_job_mock.assert_called_once_with(
        parent=_TEST_PARENT,
        batch_prediction_job=expected_request,
        timeout=None,
    )
223
+
224
def test_submit_batch_prediction_job_with_gcs_input_without_output_uri_prefix(
    self, create_batch_prediction_job_mock
):
    """Submit two GCS inputs with no output prefix, after setting a staging bucket.

    The expected output prefix is ``<staging bucket>/gen-ai-batch-prediction``.
    """
    vertexai.init(staging_bucket=_TEST_BUCKET)
    model = GenerativeModel(_TEST_GEMINI_MODEL_NAME)
    job = batch_prediction.BatchPredictionJob.submit(
        source_model=model,
        input_dataset=[_TEST_GCS_INPUT_URI, _TEST_GCS_INPUT_URI_2],
    )

    assert job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB

    # No whitespace between the bucket and the path segment: the prefix must
    # be one contiguous GCS URI.
    expected_output_prefix = f"{_TEST_BUCKET}/gen-ai-batch-prediction"
    expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat.BatchPredictionJob(
        display_name=_TEST_DISPLAY_NAME,
        model=_TEST_GEMINI_MODEL_RESOURCE_NAME,
        input_config=gca_batch_prediction_job_compat.BatchPredictionJob.InputConfig(
            instances_format="jsonl",
            gcs_source=gca_io_compat.GcsSource(
                uris=[_TEST_GCS_INPUT_URI, _TEST_GCS_INPUT_URI_2]
            ),
        ),
        output_config=gca_batch_prediction_job_compat.BatchPredictionJob.OutputConfig(
            gcs_destination=gca_io_compat.GcsDestination(
                output_uri_prefix=expected_output_prefix
            ),
            predictions_format="jsonl",
        ),
    )
    create_batch_prediction_job_mock.assert_called_once_with(
        parent=_TEST_PARENT,
        batch_prediction_job=expected_gapic_batch_prediction_job,
        timeout=None,
    )
257
+
258
def test_submit_batch_prediction_job_with_bq_input_without_output_uri_prefix(
    self, create_batch_prediction_job_mock
):
    """Submit a BigQuery input via a ``GenerativeModel`` with no output prefix."""
    gemini_model = GenerativeModel(_TEST_GEMINI_MODEL_NAME)
    submitted_job = batch_prediction.BatchPredictionJob.submit(
        source_model=gemini_model,
        input_dataset=_TEST_BQ_INPUT_URI,
    )
    assert submitted_job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB

    job_type = gca_batch_prediction_job_compat.BatchPredictionJob
    expected_input = job_type.InputConfig(
        instances_format="bigquery",
        bigquery_source=gca_io_compat.BigQuerySource(input_uri=_TEST_BQ_INPUT_URI),
    )
    # NOTE(review): the output URI presumably comes from the patched
    # `_complete_bq_uri` fixture applied to the test class; confirm against
    # the implementation of submit().
    expected_output = job_type.OutputConfig(
        bigquery_destination=gca_io_compat.BigQueryDestination(
            output_uri=_TEST_BQ_OUTPUT_PREFIX
        ),
        predictions_format="bigquery",
    )
    expected_request = job_type(
        display_name=_TEST_DISPLAY_NAME,
        model=_TEST_GEMINI_MODEL_RESOURCE_NAME,
        input_config=expected_input,
        output_config=expected_output,
    )

    create_batch_prediction_job_mock.assert_called_once_with(
        parent=_TEST_PARENT,
        batch_prediction_job=expected_request,
        timeout=None,
    )
290
+
291
def test_submit_batch_prediction_job_with_invalid_source_model(self):
    """Submitting with the PaLM model name raises ``ValueError``."""
    expected_message = (
        f"Model '{_TEST_PALM_MODEL_RESOURCE_NAME}' is not a GenAI model."
    )
    # `match` is applied as a regex search, so escape the message to compare
    # it literally ('.' would otherwise match any character).
    with pytest.raises(ValueError, match=re.escape(expected_message)):
        batch_prediction.BatchPredictionJob.submit(
            source_model=_TEST_PALM_MODEL_NAME,
            input_dataset=_TEST_GCS_INPUT_URI,
        )
300
+
301
def test_submit_batch_prediction_job_with_invalid_input_dataset(self):
    """Unsupported input URIs and multiple BigQuery inputs raise ``ValueError``."""
    # Case 1: a URI that is neither gs:// nor bq://.
    unsupported_uri_message = (
        f"Unsupported input URI: {_TEST_INVALID_URI}. "
        "Supported formats: 'gs://path/to/input/data.jsonl' and "
        "'bq://projectId.bqDatasetId.bqTableId'"
    )
    # `match` is applied as a regex search; escape so '.' and other
    # metacharacters in the message are compared literally.
    with pytest.raises(ValueError, match=re.escape(unsupported_uri_message)):
        batch_prediction.BatchPredictionJob.submit(
            source_model=_TEST_GEMINI_MODEL_NAME,
            input_dataset=_TEST_INVALID_URI,
        )

    # Case 2: more than one BigQuery input.
    invalid_bq_uris = ["bq://projectId.dataset1", "bq://projectId.dataset2"]
    multiple_bq_message = "Multiple Bigquery input datasets are not supported."
    with pytest.raises(ValueError, match=re.escape(multiple_bq_message)):
        batch_prediction.BatchPredictionJob.submit(
            source_model=_TEST_GEMINI_MODEL_NAME,
            input_dataset=invalid_bq_uris,
        )
324
+
325
def test_submit_batch_prediction_job_with_invalid_output_uri_prefix(self):
    """An output prefix that is neither gs:// nor bq:// raises ``ValueError``."""
    expected_message = (
        f"Unsupported output URI: {_TEST_INVALID_URI}. "
        "Supported formats: 'gs://path/to/output/data' and "
        "'bq://projectId.bqDatasetId'"
    )
    # `match` is applied as a regex search; escape so the message is
    # compared literally.
    with pytest.raises(ValueError, match=re.escape(expected_message)):
        batch_prediction.BatchPredictionJob.submit(
            source_model=_TEST_GEMINI_MODEL_NAME,
            input_dataset=_TEST_GCS_INPUT_URI,
            output_uri_prefix=_TEST_INVALID_URI,
        )
339
+
340
def test_submit_batch_prediction_job_without_output_uri_prefix_and_bucket(self):
    """A GCS input with no output prefix given to ``submit`` raises ``ValueError``."""
    expected_message = (
        "Please either specify output_uri_prefix or "
        "set staging_bucket in vertexai.init()."
    )
    # Unescaped, "init()." is read as an empty group followed by a wildcard;
    # escape the message so it is compared literally.
    with pytest.raises(ValueError, match=re.escape(expected_message)):
        batch_prediction.BatchPredictionJob.submit(
            source_model=_TEST_GEMINI_MODEL_NAME,
            input_dataset=_TEST_GCS_INPUT_URI,
        )
0 commit comments