@@ -76,6 +76,91 @@ class Namespace:
76
76
deny_tokens : list = field (default_factory = list )
77
77
78
78
79
@dataclass
class NumericNamespace:
    """Numeric restriction used to decide which datapoints are eligible for a query.

    NumericNamespace specifies the rules for determining the datapoints that
    are eligible for each matching query; the overall query is an AND across
    namespaces. This uses numeric comparisons.

    Args:
        name (str):
            Required. The name of this numeric namespace.
        value_int (int):
            Optional. 64 bit integer value for comparison. Must choose one among
            `value_int`, `value_float` and `value_double` for intended
            precision.
        value_float (float):
            Optional. 32 bit float value for comparison. Must choose one among
            `value_int`, `value_float` and `value_double` for
            intended precision.
        value_double (float):
            Optional. 64 bit float value for comparison. Must choose one among
            `value_int`, `value_float` and `value_double` for
            intended precision.
        op (str):
            Optional. Should be specified for queries only, not for datapoints.
            Specify one operator to use for comparison. Datapoints for which
            comparisons with query's values are true for the operator and value
            combination will be allowlisted. Choose among:
                "LESS" for datapoints' values < query's value;
                "LESS_EQUAL" for datapoints' values <= query's value;
                "EQUAL" for datapoints' values = query's value;
                "GREATER_EQUAL" for datapoints' values >= query's value;
                "GREATER" for datapoints' values > query's value;
    """

    name: str
    value_int: Optional[int] = None
    value_float: Optional[float] = None
    value_double: Optional[float] = None
    op: Optional[str] = None

    def __post_init__(self):
        """Validate the values supplied to the dataclass constructor.

        Checks that at least one of `value_int`, `value_float` and
        `value_double` is set, that every value that is set has the expected
        Python type, and that `op`, when given, names a member of the
        NumericRestriction operator enum.

        Raises:
            ValueError: If no value is provided, if a provided value has the
                wrong type, or if `op` is provided but is not a valid
                operator name.
        """
        # At least one comparison value is required.
        if (
            self.value_int is None
            and self.value_float is None
            and self.value_double is None
        ):
            raise ValueError(
                "Must choose one among `value_int`,"
                "`value_float` and `value_double` for "
                "intended precision."
            )

        # Check value types; unset (None) values are skipped.
        if self.value_int is not None and not isinstance(self.value_int, int):
            raise ValueError(
                "value_int must be of type int, got" f" {type(self.value_int)}."
            )
        if self.value_float is not None and not isinstance(self.value_float, float):
            raise ValueError(
                "value_float must be of type float, got " f"{type(self.value_float)}."
            )
        if self.value_double is not None and not isinstance(self.value_double, float):
            raise ValueError(
                "value_double must be of type float, got "
                f"{type(self.value_double)}."
            )

        # Check operator validity. `op` is optional (it defaults to None and is
        # only meant for queries), so None must be accepted here; only an
        # explicitly supplied operator is checked against the enum names.
        if (
            self.op is not None
            and self.op
            not in gca_index_v1beta1.IndexDatapoint.NumericRestriction.Operator._member_names_
        ):
            raise ValueError(
                f"Invalid operator '{self.op}'," " must be one of the valid operators."
            )
162
+
163
+
79
164
class MatchingEngineIndexEndpoint (base .VertexAiResourceNounWithFutureManager ):
80
165
"""Matching Engine index endpoint resource for Vertex AI."""
81
166
@@ -1034,6 +1119,7 @@ def find_neighbors(
1034
1119
approx_num_neighbors : Optional [int ] = None ,
1035
1120
fraction_leaf_nodes_to_search_override : Optional [float ] = None ,
1036
1121
return_full_datapoint : bool = False ,
1122
+ numeric_filter : Optional [List [NumericNamespace ]] = [],
1037
1123
) -> List [List [MatchNeighbor ]]:
1038
1124
"""Retrieves nearest neighbors for the given embedding queries on the specified deployed index which is deployed to public endpoint.
1039
1125
@@ -1082,6 +1168,11 @@ def find_neighbors(
1082
1168
Note that returning full datapoint will significantly increase the
1083
1169
latency and cost of the query.
1084
1170
1171
+ numeric_filter (Optional[list[NumericNamespace]]):
1172
+ Optional. A list of NumericNamespaces for filtering the matching
1173
+ results. For example:
1174
+ [NumericNamespace(name="cost", value_int=5, op="GREATER")]
1175
+ will match datapoints whose cost is greater than 5.
1085
1176
Returns:
1086
1177
List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
1087
1178
"""
@@ -1110,12 +1201,22 @@ def find_neighbors(
1110
1201
fraction_leaf_nodes_to_search_override
1111
1202
)
1112
1203
datapoint = gca_index_v1beta1 .IndexDatapoint (feature_vector = query )
1204
+ # Token restricts
1113
1205
for namespace in filter :
1114
1206
restrict = gca_index_v1beta1 .IndexDatapoint .Restriction ()
1115
1207
restrict .namespace = namespace .name
1116
1208
restrict .allow_list .extend (namespace .allow_tokens )
1117
1209
restrict .deny_list .extend (namespace .deny_tokens )
1118
1210
datapoint .restricts .append (restrict )
1211
+ # Numeric restricts
1212
+ for numeric_namespace in numeric_filter :
1213
+ numeric_restrict = gca_index_v1beta1 .IndexDatapoint .NumericRestriction ()
1214
+ numeric_restrict .namespace = numeric_namespace .name
1215
+ numeric_restrict .op = numeric_namespace .op
1216
+ numeric_restrict .value_int = numeric_namespace .value_int
1217
+ numeric_restrict .value_float = numeric_namespace .value_float
1218
+ numeric_restrict .value_double = numeric_namespace .value_double
1219
+ datapoint .numeric_restricts .append (numeric_restrict )
1119
1220
find_neighbors_query .datapoint = datapoint
1120
1221
find_neighbors_request .queries .append (find_neighbors_query )
1121
1222
0 commit comments