Skip to content

Commit 2690e72

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
fix: LVM - Added support for GCS storage.googleapis.com URL import in vision_models.Image
PiperOrigin-RevId: 612948528
1 parent 9eb5a52 commit 2690e72

File tree

2 files changed

+59
-1
lines changed

2 files changed

+59
-1
lines changed

tests/unit/aiplatform/test_vision_models.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,12 @@ def generate_image_from_gcs_uri(
158158
return ga_vision_models.Image.load_from_file(gcs_uri)
159159

160160

161+
def generate_image_from_storage_url(
162+
gcs_uri: str = "https://storage.googleapis.com/cloud-samples-data/vertex-ai/llm/prompts/landmark1.png",
163+
) -> ga_vision_models.Image:
164+
return ga_vision_models.Image.load_from_file(gcs_uri)
165+
166+
161167
def generate_video_from_gcs_uri(
162168
gcs_uri: str = "gs://cloud-samples-data/vertex-ai-vision/highway_vehicles.mp4",
163169
) -> ga_vision_models.Video:
@@ -894,6 +900,46 @@ def test_image_embedding_model_with_gcs_uri(self):
894900
assert embedding_response.image_embedding == test_embeddings
895901
assert embedding_response.text_embedding == test_embeddings
896902

903+
def test_image_embedding_model_with_storage_url(self):
904+
aiplatform.init(
905+
project=_TEST_PROJECT,
906+
location=_TEST_LOCATION,
907+
)
908+
with mock.patch.object(
909+
target=model_garden_service_client.ModelGardenServiceClient,
910+
attribute="get_publisher_model",
911+
return_value=gca_publisher_model.PublisherModel(
912+
_IMAGE_EMBEDDING_PUBLISHER_MODEL_DICT
913+
),
914+
):
915+
model = preview_vision_models.MultiModalEmbeddingModel.from_pretrained(
916+
"multimodalembedding@001"
917+
)
918+
919+
test_embeddings = [0, 0]
920+
gca_predict_response = gca_prediction_service.PredictResponse()
921+
gca_predict_response.predictions.append(
922+
{"imageEmbedding": test_embeddings, "textEmbedding": test_embeddings}
923+
)
924+
925+
image = generate_image_from_storage_url()
926+
assert (
927+
image._gcs_uri
928+
== "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
929+
)
930+
931+
with mock.patch.object(
932+
target=prediction_service_client.PredictionServiceClient,
933+
attribute="predict",
934+
return_value=gca_predict_response,
935+
):
936+
embedding_response = model.get_embeddings(
937+
image=image, contextual_text="hello world"
938+
)
939+
940+
assert embedding_response.image_embedding == test_embeddings
941+
assert embedding_response.text_embedding == test_embeddings
942+
897943
def test_video_embedding_model_with_only_video(self):
898944
aiplatform.init(
899945
project=_TEST_PROJECT,

vertexai/vision_models/_vision_models.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import pathlib
2323
import typing
2424
from typing import Any, Dict, List, Optional, Union
25+
import urllib
2526

2627
from google.cloud import storage
2728

@@ -80,9 +81,20 @@ def load_from_file(location: str) -> "Image":
8081
Returns:
8182
Loaded image as an `Image` object.
8283
"""
83-
if location.startswith("gs://"):
84+
parsed_url = urllib.parse.urlparse(location)
85+
if (
86+
parsed_url.scheme == "https"
87+
and parsed_url.netloc == "storage.googleapis.com"
88+
):
89+
parsed_url = parsed_url._replace(
90+
scheme="gs", netloc="", path=f"/{urllib.parse.unquote(parsed_url.path)}"
91+
)
92+
location = urllib.parse.urlunparse(parsed_url)
93+
94+
if parsed_url.scheme == "gs":
8495
return Image(gcs_uri=location)
8596

97+
# Load image from local path
8698
image_bytes = pathlib.Path(location).read_bytes()
8799
image = Image(image_bytes=image_bytes)
88100
return image

0 commit comments

Comments
 (0)