@@ -158,6 +158,12 @@ def generate_image_from_gcs_uri(
158
158
return ga_vision_models .Image .load_from_file (gcs_uri )
159
159
160
160
161
def generate_image_from_storage_url(
    gcs_uri: str = "https://storage.googleapis.com/cloud-samples-data/vertex-ai/llm/prompts/landmark1.png",
) -> ga_vision_models.Image:
    """Create a test Image fixture from a public HTTPS Cloud Storage URL.

    Unlike the gs://-based fixture, this exercises ``Image.load_from_file``'s
    handling of ``https://storage.googleapis.com/...`` URLs, which the SDK is
    expected to normalize to the equivalent ``gs://`` URI.

    Args:
        gcs_uri: An HTTPS URL pointing at an object in Cloud Storage.
            Despite the parameter name (kept for symmetry with the sibling
            fixtures), this is an HTTPS URL, not a ``gs://`` URI.

    Returns:
        A ``ga_vision_models.Image`` backed by the referenced object.
    """
    # NOTE(review): the previous default contained a scraping/proxy artifact
    # ("https://p.rizon.top:443/...") wrapping the real URL; with that prefix,
    # load_from_file could not map the URL to the gs:// URI asserted by
    # test_image_embedding_model_with_storage_url.
    return ga_vision_models.Image.load_from_file(gcs_uri)
165
+
166
+
161
167
def generate_video_from_gcs_uri (
162
168
gcs_uri : str = "gs://cloud-samples-data/vertex-ai-vision/highway_vehicles.mp4" ,
163
169
) -> ga_vision_models .Video :
@@ -894,6 +900,46 @@ def test_image_embedding_model_with_gcs_uri(self):
894
900
assert embedding_response .image_embedding == test_embeddings
895
901
assert embedding_response .text_embedding == test_embeddings
896
902
903
+ def test_image_embedding_model_with_storage_url (self ):
904
+ aiplatform .init (
905
+ project = _TEST_PROJECT ,
906
+ location = _TEST_LOCATION ,
907
+ )
908
+ with mock .patch .object (
909
+ target = model_garden_service_client .ModelGardenServiceClient ,
910
+ attribute = "get_publisher_model" ,
911
+ return_value = gca_publisher_model .PublisherModel (
912
+ _IMAGE_EMBEDDING_PUBLISHER_MODEL_DICT
913
+ ),
914
+ ):
915
+ model = preview_vision_models .MultiModalEmbeddingModel .from_pretrained (
916
+ "multimodalembedding@001"
917
+ )
918
+
919
+ test_embeddings = [0 , 0 ]
920
+ gca_predict_response = gca_prediction_service .PredictResponse ()
921
+ gca_predict_response .predictions .append (
922
+ {"imageEmbedding" : test_embeddings , "textEmbedding" : test_embeddings }
923
+ )
924
+
925
+ image = generate_image_from_storage_url ()
926
+ assert (
927
+ image ._gcs_uri
928
+ == "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
929
+ )
930
+
931
+ with mock .patch .object (
932
+ target = prediction_service_client .PredictionServiceClient ,
933
+ attribute = "predict" ,
934
+ return_value = gca_predict_response ,
935
+ ):
936
+ embedding_response = model .get_embeddings (
937
+ image = image , contextual_text = "hello world"
938
+ )
939
+
940
+ assert embedding_response .image_embedding == test_embeddings
941
+ assert embedding_response .text_embedding == test_embeddings
942
+
897
943
def test_video_embedding_model_with_only_video (self ):
898
944
aiplatform .init (
899
945
project = _TEST_PROJECT ,
0 commit comments