Skip to content

Commit 6c1f2cc

Browse files
lingyinwcopybara-github
authored andcommitted
feat: add numeric_restricts to MatchingEngineIndex find_neighbors() for querying
public endpoints. PiperOrigin-RevId: 582893369
1 parent 3a8f22c commit 6c1f2cc

File tree

2 files changed

+203
-0
lines changed

2 files changed

+203
-0
lines changed

google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,91 @@ class Namespace:
7676
deny_tokens: list = field(default_factory=list)
7777

7878

79+
@dataclass
80+
class NumericNamespace:
81+
"""NumericNamespace specifies the rules for determining the datapoints that
82+
are eligible for each matching query, overall query is an AND across namespaces.
83+
This uses numeric comparisons.
84+
85+
Args:
86+
name (str):
87+
Required. The name of this numeric namespace.
88+
value_int (int):
89+
Optional. 64 bit integer value for comparison. Must choose one among
90+
`value_int`, `value_float` and `value_double` for intended
91+
precision.
92+
value_float (float):
93+
Optional. 32 bit float value for comparison. Must choose one among
94+
`value_int`, `value_float` and `value_double` for
95+
intended precision.
96+
value_double (float):
97+
Optional. 64b bit float value for comparison. Must choose one among
98+
`value_int`, `value_float` and `value_double` for
99+
intended precision.
100+
operator (str):
101+
Optional. Should be specified for query only, not for a datapoints.
102+
Specify one operator to use for comparison. Datapoints for which
103+
comparisons with query's values are true for the operator and value
104+
combination will be allowlisted. Choose among:
105+
"LESS" for datapoints' values < query's value;
106+
"LESS_EQUAL" for datapoints' values <= query's value;
107+
"EQUAL" for datapoints' values = query's value;
108+
"GREATER_EQUAL" for datapoints' values >= query's value;
109+
"GREATER" for datapoints' values > query's value;
110+
"""
111+
112+
name: str
113+
value_int: Optional[int] = None
114+
value_float: Optional[float] = None
115+
value_double: Optional[float] = None
116+
op: Optional[str] = None
117+
118+
def __post_init__(self):
119+
"""Check NumericNamespace values are of correct types and values are
120+
not all none.
121+
Args:
122+
None.
123+
124+
Raises:
125+
ValueError: Numeric Namespace provided values must be of correct
126+
types and one of value_int, value_float, value_double must exist.
127+
"""
128+
# Check one of
129+
if (
130+
self.value_int is None
131+
and self.value_float is None
132+
and self.value_double is None
133+
):
134+
raise ValueError(
135+
"Must choose one among `value_int`,"
136+
"`value_float` and `value_double` for "
137+
"intended precision."
138+
)
139+
140+
# Check value type
141+
if self.value_int is not None and not isinstance(self.value_int, int):
142+
raise ValueError(
143+
"value_int must be of type int, got" f" { type(self.value_int)}."
144+
)
145+
if self.value_float is not None and not isinstance(self.value_float, float):
146+
raise ValueError(
147+
"value_float must be of type float, got " f"{ type(self.value_float)}."
148+
)
149+
if self.value_double is not None and not isinstance(self.value_double, float):
150+
raise ValueError(
151+
"value_double must be of type float, got "
152+
f"{ type(self.value_double)}."
153+
)
154+
# Check operator validity
155+
if (
156+
self.op
157+
not in gca_index_v1beta1.IndexDatapoint.NumericRestriction.Operator._member_names_
158+
):
159+
raise ValueError(
160+
f"Invalid operator '{self.op}'," " must be one of the valid operators."
161+
)
162+
163+
79164
class MatchingEngineIndexEndpoint(base.VertexAiResourceNounWithFutureManager):
80165
"""Matching Engine index endpoint resource for Vertex AI."""
81166

@@ -1034,6 +1119,7 @@ def find_neighbors(
10341119
approx_num_neighbors: Optional[int] = None,
10351120
fraction_leaf_nodes_to_search_override: Optional[float] = None,
10361121
return_full_datapoint: bool = False,
1122+
numeric_filter: Optional[List[NumericNamespace]] = [],
10371123
) -> List[List[MatchNeighbor]]:
10381124
"""Retrieves nearest neighbors for the given embedding queries on the specified deployed index which is deployed to public endpoint.
10391125
@@ -1082,6 +1168,11 @@ def find_neighbors(
10821168
Note that returning full datapoint will significantly increase the
10831169
latency and cost of the query.
10841170
1171+
numeric_filter (Optional[list[NumericNamespace]]):
1172+
Optional. A list of NumericNamespaces for filtering the matching
1173+
results. For example:
1174+
[NumericNamespace(name="cost", value_int=5, op="GREATER")]
1175+
will match datapoints that its cost is greater than 5.
10851176
Returns:
10861177
List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
10871178
"""
@@ -1110,12 +1201,22 @@ def find_neighbors(
11101201
fraction_leaf_nodes_to_search_override
11111202
)
11121203
datapoint = gca_index_v1beta1.IndexDatapoint(feature_vector=query)
1204+
# Token restricts
11131205
for namespace in filter:
11141206
restrict = gca_index_v1beta1.IndexDatapoint.Restriction()
11151207
restrict.namespace = namespace.name
11161208
restrict.allow_list.extend(namespace.allow_tokens)
11171209
restrict.deny_list.extend(namespace.deny_tokens)
11181210
datapoint.restricts.append(restrict)
1211+
# Numeric restricts
1212+
for numeric_namespace in numeric_filter:
1213+
numeric_restrict = gca_index_v1beta1.IndexDatapoint.NumericRestriction()
1214+
numeric_restrict.namespace = numeric_namespace.name
1215+
numeric_restrict.op = numeric_namespace.op
1216+
numeric_restrict.value_int = numeric_namespace.value_int
1217+
numeric_restrict.value_float = numeric_namespace.value_float
1218+
numeric_restrict.value_double = numeric_namespace.value_double
1219+
datapoint.numeric_restricts.append(numeric_restrict)
11191220
find_neighbors_query.datapoint = datapoint
11201221
find_neighbors_request.queries.append(find_neighbors_query)
11211222

tests/unit/aiplatform/test_matching_engine_index_endpoint.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from google.cloud.aiplatform.matching_engine._protos import match_service_pb2
2828
from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import (
2929
Namespace,
30+
NumericNamespace,
3031
)
3132
from google.cloud.aiplatform.compat.types import (
3233
matching_engine_deployed_index_ref as gca_matching_engine_deployed_index_ref,
@@ -233,6 +234,11 @@
233234
_TEST_FILTER = [
234235
Namespace(name="class", allow_tokens=["token_1"], deny_tokens=["token_2"])
235236
]
237+
_TEST_NUMERIC_FILTER = [
238+
NumericNamespace(name="cost", value_double=0.3, op="EQUAL"),
239+
NumericNamespace(name="size", value_int=10, op="GREATER"),
240+
NumericNamespace(name="seconds", value_float=20.5, op="LESS_EQUAL"),
241+
]
236242
_TEST_IDS = ["123", "456", "789"]
237243
_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS = 3
238244
_TEST_APPROX_NUM_NEIGHBORS = 2
@@ -1080,6 +1086,102 @@ def test_index_public_endpoint_match_queries(
10801086
find_neighbors_request
10811087
)
10821088

1089+
@pytest.mark.usefixtures("get_index_public_endpoint_mock")
1090+
def test_index_public_endpoint_match_queries_with_numeric_filtering(
1091+
self, index_public_endpoint_match_queries_mock
1092+
):
1093+
aiplatform.init(project=_TEST_PROJECT)
1094+
1095+
my_pubic_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
1096+
index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
1097+
)
1098+
1099+
my_pubic_index_endpoint.find_neighbors(
1100+
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
1101+
queries=_TEST_QUERIES,
1102+
num_neighbors=_TEST_NUM_NEIGHBOURS,
1103+
filter=_TEST_FILTER,
1104+
per_crowding_attribute_neighbor_count=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
1105+
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
1106+
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
1107+
return_full_datapoint=_TEST_RETURN_FULL_DATAPOINT,
1108+
numeric_filter=_TEST_NUMERIC_FILTER,
1109+
)
1110+
1111+
find_neighbors_request = gca_match_service_v1beta1.FindNeighborsRequest(
1112+
index_endpoint=my_pubic_index_endpoint.resource_name,
1113+
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
1114+
queries=[
1115+
gca_match_service_v1beta1.FindNeighborsRequest.Query(
1116+
neighbor_count=_TEST_NUM_NEIGHBOURS,
1117+
datapoint=gca_index_v1beta1.IndexDatapoint(
1118+
feature_vector=_TEST_QUERIES[0],
1119+
restricts=[
1120+
gca_index_v1beta1.IndexDatapoint.Restriction(
1121+
namespace="class",
1122+
allow_list=["token_1"],
1123+
deny_list=["token_2"],
1124+
)
1125+
],
1126+
numeric_restricts=[
1127+
gca_index_v1beta1.IndexDatapoint.NumericRestriction(
1128+
namespace="cost", value_double=0.3, op="EQUAL"
1129+
),
1130+
gca_index_v1beta1.IndexDatapoint.NumericRestriction(
1131+
namespace="size", value_int=10, op="GREATER"
1132+
),
1133+
gca_index_v1beta1.IndexDatapoint.NumericRestriction(
1134+
namespace="seconds", value_float=20.5, op="LESS_EQUAL"
1135+
),
1136+
],
1137+
),
1138+
per_crowding_attribute_neighbor_count=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
1139+
approximate_neighbor_count=_TEST_APPROX_NUM_NEIGHBORS,
1140+
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
1141+
)
1142+
],
1143+
return_full_datapoint=_TEST_RETURN_FULL_DATAPOINT,
1144+
)
1145+
1146+
index_public_endpoint_match_queries_mock.assert_called_with(
1147+
find_neighbors_request
1148+
)
1149+
1150+
def test_post_init_numeric_filter_invalid_operator_throws_exception(
1151+
self,
1152+
):
1153+
expected_message = (
1154+
"Invalid operator 'NOT_EQ', must be one of the valid operators."
1155+
)
1156+
with pytest.raises(ValueError) as exception:
1157+
NumericNamespace(name="cost", value_int=3, op="NOT_EQ")
1158+
1159+
assert str(exception.value) == expected_message
1160+
1161+
def test_post_init_numeric_namespace_missing_value_throws_exception(self):
1162+
aiplatform.init(project=_TEST_PROJECT)
1163+
1164+
expected_message = (
1165+
"Must choose one among `value_int`,"
1166+
"`value_float` and `value_double` for "
1167+
"intended precision."
1168+
)
1169+
1170+
with pytest.raises(ValueError) as exception:
1171+
NumericNamespace(name="cost", op="EQUAL")
1172+
1173+
assert str(exception.value) == expected_message
1174+
1175+
def test_index_public_endpoint_match_queries_with_numeric_filtering_value_type_mismatch_throws_exception(
1176+
self,
1177+
):
1178+
expected_message = "value_int must be of type int, got <class 'float'>."
1179+
1180+
with pytest.raises(ValueError) as exception:
1181+
NumericNamespace(name="cost", value_int=0.3, op="EQUAL")
1182+
1183+
assert str(exception.value) == expected_message
1184+
10831185
@pytest.mark.usefixtures("get_index_public_endpoint_mock")
10841186
def test_index_public_endpoint_read_index_datapoints(
10851187
self, index_public_endpoint_read_index_datapoints_mock

0 commit comments

Comments
 (0)