Skip to content

Commit 73c0dae

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: batch_predict method generally-available at TextEmbeddingModel.
PiperOrigin-RevId: 676573556
1 parent c0626fe commit 73c0dae

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

tests/unit/aiplatform/test_language_models.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4739,6 +4739,40 @@ def test_batch_prediction_for_code_generation(self):
47394739
)
47404740

47414741
def test_batch_prediction_for_text_embedding(self):
4742+
"""Tests batch prediction."""
4743+
aiplatform.init(
4744+
project=_TEST_PROJECT,
4745+
location=_TEST_LOCATION,
4746+
)
4747+
with mock.patch.object(
4748+
target=model_garden_service_client.ModelGardenServiceClient,
4749+
attribute="get_publisher_model",
4750+
return_value=gca_publisher_model.PublisherModel(
4751+
_TEXT_EMBEDDING_GECKO_PUBLISHER_MODEL_DICT
4752+
),
4753+
):
4754+
model = language_models.TextEmbeddingModel.from_pretrained(
4755+
"textembedding-gecko@001"
4756+
)
4757+
4758+
with mock.patch.object(
4759+
target=aiplatform.BatchPredictionJob,
4760+
attribute="create",
4761+
) as mock_create:
4762+
model.batch_predict(
4763+
dataset="gs://test-bucket/test_table.jsonl",
4764+
destination_uri_prefix="gs://test-bucket/results/",
4765+
model_parameters={},
4766+
)
4767+
mock_create.assert_called_once_with(
4768+
model_name=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/textembedding-gecko@001",
4769+
job_display_name=None,
4770+
gcs_source="gs://test-bucket/test_table.jsonl",
4771+
gcs_destination_prefix="gs://test-bucket/results/",
4772+
model_parameters={},
4773+
)
4774+
4775+
def test_batch_prediction_for_text_embedding_preview(self):
47424776
"""Tests batch prediction."""
47434777
aiplatform.init(
47444778
project=_TEST_PROJECT,

vertexai/language_models/_language_models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2421,6 +2421,7 @@ class _TunableTextEmbeddingModelMixin(_PreviewTunableTextEmbeddingModelMixin):
24212421

24222422
class TextEmbeddingModel(
24232423
_TextEmbeddingModel,
2424+
_ModelWithBatchPredict,
24242425
_TunableTextEmbeddingModelMixin,
24252426
_CountTokensMixin,
24262427
):
@@ -2430,8 +2431,8 @@ class TextEmbeddingModel(
24302431
class _PreviewTextEmbeddingModel(
24312432
_TextEmbeddingModel,
24322433
_ModelWithBatchPredict,
2433-
_CountTokensMixin,
24342434
_PreviewTunableTextEmbeddingModelMixin,
2435+
_CountTokensMixin,
24352436
):
24362437
__name__ = "TextEmbeddingModel"
24372438
__module__ = "vertexai.preview.language_models"

0 commit comments

Comments
 (0)