Skip to content

Commit c1cb33f

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Add filter to Model Registry list_versions API.
PiperOrigin-RevId: 501046171
1 parent fa434e0 commit c1cb33f

File tree

2 files changed

+65
-12
lines changed

2 files changed

+65
-12
lines changed

google/cloud/aiplatform/models.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -4796,9 +4796,22 @@ def get_model(
47964796

47974797
def list_versions(
47984798
self,
4799+
filter: Optional[str] = None,
47994800
) -> List[VersionInfo]:
48004801
"""Lists the versions and version info of a model.
48014802
4803+
Args:
4804+
filter (str):
4805+
Optional. An expression for filtering the results of the request.
4806+
For field names both snake_case and camelCase are supported.
4807+
- `labels` supports general map functions that is:
4808+
- `labels.key=value` - key:value equality
4809+
- `labels.key:* or labels:key - key existence
4810+
- A key including a space must be quoted.
4811+
`labels."a key"`.
4812+
Some examples:
4813+
- `labels.myKey="myValue"`
4814+
48024815
Returns:
48034816
List[VersionInfo]:
48044817
A list of VersionInfo, each containing
@@ -4807,8 +4820,13 @@ def list_versions(
48074820

48084821
_LOGGER.info(f"Getting versions for {self.model_resource_name}")
48094822

4810-
page_result = self.client.list_model_versions(
4823+
request = gca_model_service_compat.ListModelVersionsRequest(
48114824
name=self.model_resource_name,
4825+
filter=filter,
4826+
)
4827+
4828+
page_result = self.client.list_model_versions(
4829+
request=request,
48124830
)
48134831

48144832
versions = [

tests/unit/aiplatform/test_models.py

+46-11
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,9 @@
255255
_TEST_VERSION_ID = "2"
256256
_TEST_VERSION_ALIAS_1 = "myalias"
257257
_TEST_VERSION_ALIAS_2 = "youralias"
258-
_TEST_MODEL_VERSION_DESCRIPTION = "My version description"
258+
_TEST_MODEL_VERSION_DESCRIPTION_1 = "My version 1 description"
259+
_TEST_MODEL_VERSION_DESCRIPTION_2 = "My version 2 description"
260+
_TEST_MODEL_VERSION_DESCRIPTION_3 = "My version 3 description"
259261

260262
_TEST_MODEL_VERSIONS_LIST = [
261263
gca_model.Model(
@@ -265,7 +267,7 @@
265267
display_name=_TEST_MODEL_NAME,
266268
name=f"{_TEST_MODEL_PARENT}@1",
267269
version_aliases=["default"],
268-
version_description=_TEST_MODEL_VERSION_DESCRIPTION,
270+
version_description=_TEST_MODEL_VERSION_DESCRIPTION_1,
269271
),
270272
gca_model.Model(
271273
version_id="2",
@@ -274,7 +276,7 @@
274276
display_name=_TEST_MODEL_NAME,
275277
name=f"{_TEST_MODEL_PARENT}@2",
276278
version_aliases=[_TEST_VERSION_ALIAS_1, _TEST_VERSION_ALIAS_2],
277-
version_description=_TEST_MODEL_VERSION_DESCRIPTION,
279+
version_description=_TEST_MODEL_VERSION_DESCRIPTION_2,
278280
),
279281
gca_model.Model(
280282
version_id="3",
@@ -283,9 +285,11 @@
283285
display_name=_TEST_MODEL_NAME,
284286
name=f"{_TEST_MODEL_PARENT}@3",
285287
version_aliases=[],
286-
version_description=_TEST_MODEL_VERSION_DESCRIPTION,
288+
version_description=_TEST_MODEL_VERSION_DESCRIPTION_3,
289+
labels=_TEST_LABEL,
287290
),
288291
]
292+
_TEST_MODEL_VERSIONS_WITH_FILTER_LIST = [_TEST_MODEL_VERSIONS_LIST[2]]
289293

290294
_TEST_MODELS_LIST = _TEST_MODEL_VERSIONS_LIST + [
291295
gca_model.Model(
@@ -295,7 +299,7 @@
295299
display_name=_TEST_MODEL_NAME_ALT,
296300
name=_TEST_MODEL_PARENT_ALT,
297301
version_aliases=["default"],
298-
version_description=_TEST_MODEL_VERSION_DESCRIPTION,
302+
version_description=_TEST_MODEL_VERSION_DESCRIPTION_1,
299303
),
300304
]
301305

@@ -306,7 +310,7 @@
306310
display_name=_TEST_MODEL_NAME,
307311
name=f"{_TEST_MODEL_PARENT}@{_TEST_VERSION_ID}",
308312
version_aliases=[_TEST_VERSION_ALIAS_1, _TEST_VERSION_ALIAS_2],
309-
version_description=_TEST_MODEL_VERSION_DESCRIPTION,
313+
version_description=_TEST_MODEL_VERSION_DESCRIPTION_2,
310314
)
311315

312316
_TEST_NETWORK = f"projects/{_TEST_PROJECT}/global/networks/{_TEST_ID}"
@@ -683,6 +687,15 @@ def list_model_versions_mock():
683687
yield list_model_versions_mock
684688

685689

690+
@pytest.fixture
691+
def list_model_versions_with_filter_mock():
692+
with mock.patch.object(
693+
model_service_client.ModelServiceClient, "list_model_versions"
694+
) as list_model_versions_mock:
695+
list_model_versions_mock.return_value = _TEST_MODEL_VERSIONS_WITH_FILTER_LIST
696+
yield list_model_versions_mock
697+
698+
686699
@pytest.fixture
687700
def list_models_mock():
688701
with mock.patch.object(
@@ -2514,7 +2527,7 @@ def test_init_with_version_in_resource_name(self, get_model_with_version):
25142527
assert model.display_name == _TEST_MODEL_NAME
25152528
assert model.resource_name == _TEST_MODEL_PARENT
25162529
assert model.version_id == _TEST_VERSION_ID
2517-
assert model.version_description == _TEST_MODEL_VERSION_DESCRIPTION
2530+
assert model.version_description == _TEST_MODEL_VERSION_DESCRIPTION_2
25182531
# The Model yielded from upload should not have a version in resource name
25192532
assert "@" not in model.resource_name
25202533
# The Model yielded from upload SHOULD have a version in the versioned resource name
@@ -2527,7 +2540,7 @@ def test_init_with_version_arg(self, get_model_with_version):
25272540
assert model.display_name == _TEST_MODEL_NAME
25282541
assert model.resource_name == _TEST_MODEL_PARENT
25292542
assert model.version_id == _TEST_VERSION_ID
2530-
assert model.version_description == _TEST_MODEL_VERSION_DESCRIPTION
2543+
assert model.version_description == _TEST_MODEL_VERSION_DESCRIPTION_2
25312544
# The Model yielded from upload should not have a version in resource name
25322545
assert "@" not in model.resource_name
25332546
# The Model yielded from upload SHOULD have a version in the versioned resource name
@@ -2584,7 +2597,7 @@ def test_upload_new_version(
25842597
"upload_request_timeout": None,
25852598
"model_id": _TEST_ID,
25862599
"parent_model": parent,
2587-
"version_description": _TEST_MODEL_VERSION_DESCRIPTION,
2600+
"version_description": _TEST_MODEL_VERSION_DESCRIPTION_2,
25882601
"version_aliases": aliases,
25892602
"is_default_version": default,
25902603
}
@@ -2610,7 +2623,7 @@ def test_upload_new_version(
26102623
assert upload_model_request.model.version_aliases == goal
26112624
assert (
26122625
upload_model_request.model.version_description
2613-
== _TEST_MODEL_VERSION_DESCRIPTION
2626+
== _TEST_MODEL_VERSION_DESCRIPTION_2
26142627
)
26152628
assert upload_model_request.parent_model == _TEST_MODEL_PARENT
26162629
assert upload_model_request.model_id == _TEST_ID
@@ -2622,7 +2635,7 @@ def test_get_model_instance_from_registry(self, get_model_with_version):
26222635
assert model.display_name == _TEST_MODEL_NAME
26232636
assert model.resource_name == _TEST_MODEL_PARENT
26242637
assert model.version_id == _TEST_VERSION_ID
2625-
assert model.version_description == _TEST_MODEL_VERSION_DESCRIPTION
2638+
assert model.version_description == _TEST_MODEL_VERSION_DESCRIPTION_2
26262639

26272640
def test_list_versions(self, list_model_versions_mock, get_model_with_version):
26282641
my_model = models.Model(_TEST_MODEL_NAME, _TEST_PROJECT, _TEST_LOCATION)
@@ -2643,6 +2656,28 @@ def test_list_versions(self, list_model_versions_mock, get_model_with_version):
26432656
assert model.name.startswith(ver.model_resource_name)
26442657
assert model.name.endswith(ver.version_id)
26452658

2659+
def test_list_versions_with_filter(
2660+
self, list_model_versions_with_filter_mock, get_model_with_version
2661+
):
2662+
my_model = models.Model(_TEST_MODEL_NAME, _TEST_PROJECT, _TEST_LOCATION)
2663+
versions = my_model.versioning_registry.list_versions(
2664+
filter='labels.team="experimentation"'
2665+
)
2666+
2667+
assert len(versions) == len(_TEST_MODEL_VERSIONS_WITH_FILTER_LIST)
2668+
2669+
ver = versions[0]
2670+
model = _TEST_MODEL_VERSIONS_WITH_FILTER_LIST[0]
2671+
assert ver.version_id == "3"
2672+
assert ver.version_create_time == model.version_create_time
2673+
assert ver.version_update_time == model.version_update_time
2674+
assert ver.model_display_name == model.display_name
2675+
assert ver.version_aliases == model.version_aliases
2676+
assert ver.version_description == model.version_description
2677+
2678+
assert model.name.startswith(ver.model_resource_name)
2679+
assert model.name.endswith(ver.version_id)
2680+
26462681
def test_get_version_info(self, get_model_with_version):
26472682
my_model = models.Model(_TEST_MODEL_NAME, _TEST_PROJECT, _TEST_LOCATION)
26482683
ver = my_model.versioning_registry.get_version_info("2")

0 commit comments

Comments
 (0)