Skip to content

Commit 2c93fc1

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Add multithreading for custom metric computation.
PiperOrigin-RevId: 663889197
1 parent b78714f commit 2c93fc1

File tree

1 file changed

+28
-16
lines changed

1 file changed

+28
-16
lines changed

vertexai/preview/evaluation/_evaluation.py

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -164,36 +164,46 @@ def _compute_custom_metrics(
164164
row_dict: Dict[str, Any],
165165
custom_metrics: List[metrics_base.CustomMetric],
166166
pbar: tqdm,
167+
executor: futures.ThreadPoolExecutor,
167168
) -> Dict[str, Any]:
168169
"""Computes custom metrics for a row.
169170
170171
Args:
171172
row_dict: A dictionary of an instance in the eval dataset.
172173
custom_metrics: A list of CustomMetrics.
173174
pbar: A tqdm progress bar.
175+
executor: A thread pool executor.
174176
175177
Returns:
176178
A dictionary of an instance containing custom metric results.
177179
178180
Raises:
179181
KeyError: If the custom metric function does not return a valid output.
180182
"""
183+
futures_by_metric = collections.defaultdict(list)
181184
for custom_metric in custom_metrics:
182-
metric_output = custom_metric.metric_function(row_dict)
183-
pbar.update(1)
184-
if custom_metric.name in metric_output:
185-
row_dict[custom_metric.name] = metric_output[custom_metric.name]
186-
else:
187-
raise KeyError(
188-
f"Custom metric score `{custom_metric.name}` not found in the metric"
189-
f" output {metric_output}. Please make sure the custom metric"
190-
" function is valid, and the output dictionary uses"
191-
f" `{custom_metric.name}` as the key for metric value."
192-
)
193-
# Include additional metric results like explanation.
194-
for key, value in metric_output.items():
195-
if key != custom_metric.name:
196-
row_dict[f"{custom_metric.name}/{key}"] = value
185+
future = executor.submit(custom_metric.metric_function, row_dict)
186+
future.add_done_callback(lambda _: pbar.update(1))
187+
futures_by_metric[custom_metric].append(future)
188+
189+
for custom_metric, futures_list in futures_by_metric.items():
190+
for future in futures_list:
191+
metric_output = future.result()
192+
try:
193+
row_dict[
194+
f"{custom_metric.name}/{constants.MetricResult.SCORE_KEY}"
195+
] = metric_output[custom_metric.name]
196+
except KeyError:
197+
raise KeyError(
198+
f"Custom metric score `{custom_metric.name}` not found in the metric"
199+
f" output {metric_output}. Please make sure the custom metric"
200+
" function is valid, and the output dictionary uses"
201+
f" `{custom_metric.name}` as the key for metric value."
202+
)
203+
# Include additional metric results like explanation.
204+
for key, value in metric_output.items():
205+
if key != custom_metric.name:
206+
row_dict[f"{custom_metric.name}/{key}"] = value
197207
return row_dict
198208

199209

@@ -638,7 +648,9 @@ def _compute_metrics(
638648
with tqdm(total=total_request_count) as pbar:
639649
with futures.ThreadPoolExecutor(max_workers=constants.MAX_WORKERS) as executor:
640650
for idx, row in evaluation_run_config.dataset.iterrows():
641-
row_dict = _compute_custom_metrics(row.to_dict(), custom_metrics, pbar)
651+
row_dict = _compute_custom_metrics(
652+
row.to_dict(), custom_metrics, pbar, executor
653+
)
642654

643655
instance_list.append(row_dict)
644656

0 commit comments

Comments
 (0)