@@ -164,36 +164,46 @@ def _compute_custom_metrics(
164
164
row_dict : Dict [str , Any ],
165
165
custom_metrics : List [metrics_base .CustomMetric ],
166
166
pbar : tqdm ,
167
+ executor : futures .ThreadPoolExecutor ,
167
168
) -> Dict [str , Any ]:
168
169
"""Computes custom metrics for a row.
169
170
170
171
Args:
171
172
row_dict: A dictionary of an instance in the eval dataset.
172
173
custom_metrics: A list of CustomMetrics.
173
174
pbar: A tqdm progress bar.
175
+ executor: A thread pool executor.
174
176
175
177
Returns:
176
178
A dictionary of an instance containing custom metric results.
177
179
178
180
Raises:
179
181
KeyError: If the custom metric function does not return a valid output.
180
182
"""
183
+ futures_by_metric = collections .defaultdict (list )
181
184
for custom_metric in custom_metrics :
182
- metric_output = custom_metric .metric_function (row_dict )
183
- pbar .update (1 )
184
- if custom_metric .name in metric_output :
185
- row_dict [custom_metric .name ] = metric_output [custom_metric .name ]
186
- else :
187
- raise KeyError (
188
- f"Custom metric score `{ custom_metric .name } ` not found in the metric"
189
- f" output { metric_output } . Please make sure the custom metric"
190
- " function is valid, and the output dictionary uses"
191
- f" `{ custom_metric .name } ` as the key for metric value."
192
- )
193
- # Include additional metric results like explanation.
194
- for key , value in metric_output .items ():
195
- if key != custom_metric .name :
196
- row_dict [f"{ custom_metric .name } /{ key } " ] = value
185
+ future = executor .submit (custom_metric .metric_function , row_dict )
186
+ future .add_done_callback (lambda _ : pbar .update (1 ))
187
+ futures_by_metric [custom_metric ].append (future )
188
+
189
+ for custom_metric , futures_list in futures_by_metric .items ():
190
+ for future in futures_list :
191
+ metric_output = future .result ()
192
+ try :
193
+ row_dict [
194
+ f"{ custom_metric .name } /{ constants .MetricResult .SCORE_KEY } "
195
+ ] = metric_output [custom_metric .name ]
196
+ except KeyError :
197
+ raise KeyError (
198
+ f"Custom metric score `{ custom_metric .name } ` not found in the metric"
199
+ f" output { metric_output } . Please make sure the custom metric"
200
+ " function is valid, and the output dictionary uses"
201
+ f" `{ custom_metric .name } ` as the key for metric value."
202
+ )
203
+ # Include additional metric results like explanation.
204
+ for key , value in metric_output .items ():
205
+ if key != custom_metric .name :
206
+ row_dict [f"{ custom_metric .name } /{ key } " ] = value
197
207
return row_dict
198
208
199
209
@@ -638,7 +648,9 @@ def _compute_metrics(
638
648
with tqdm (total = total_request_count ) as pbar :
639
649
with futures .ThreadPoolExecutor (max_workers = constants .MAX_WORKERS ) as executor :
640
650
for idx , row in evaluation_run_config .dataset .iterrows ():
641
- row_dict = _compute_custom_metrics (row .to_dict (), custom_metrics , pbar )
651
+ row_dict = _compute_custom_metrics (
652
+ row .to_dict (), custom_metrics , pbar , executor
653
+ )
642
654
643
655
instance_list .append (row_dict )
644
656
0 commit comments