Skip to content

Commit 831c8e4

Browse files
jsondai authored and copybara-github committed
feat: add metric classes for 2 pairwise metrics for rapid evaluation SDK.
PiperOrigin-RevId: 644127034
1 parent 361b805 commit 831c8e4

File tree

4 files changed

+101
-7
lines changed

4 files changed

+101
-7
lines changed

tests/unit/vertexai/test_evaluation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ def test_compute_pairwise_metrics_with_model_inference(self, api_transport):
332332
mock_candidate_model._model_name = "publishers/google/model/gemini-pro"
333333
test_metrics = [
334334
evaluation.PairwiseMetric(
335-
metric="summarization_quality",
335+
metric="pairwise_summarization_quality",
336336
baseline_model=mock_baseline_model,
337337
use_reference=False,
338338
)
@@ -609,11 +609,11 @@ def test_evaluate_pairwise_metrics_with_multiple_baseline_models(self):
609609
mock_candidate_model._model_name = "publishers/google/model/gemini-1.0-ultra"
610610
test_metrics = [
611611
evaluation.PairwiseMetric(
612-
metric="summarization_quality",
612+
metric="pairwise_summarization_quality",
613613
baseline_model=mock_baseline_model_1,
614614
),
615615
evaluation.PairwiseMetric(
616-
metric="summarization_quality",
616+
metric="pairwise_summarization_quality",
617617
baseline_model=mock_baseline_model_2,
618618
),
619619
]

vertexai/preview/evaluation/metrics/_base.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from typing import Any, Callable, Dict, Literal, Optional, Union
1919
from vertexai import generative_models
20+
from vertexai.preview.evaluation import constants
2021

2122

2223
class PairwiseMetric:
@@ -85,9 +86,12 @@ class PairwiseMetric:
8586
def __init__(
8687
self,
8788
*,
88-
metric: Literal["summarization_quality", "question_answering_quality"],
89-
baseline_model: Union[
90-
generative_models.GenerativeModel, Callable[[str], str]
89+
metric: Literal[
90+
constants.Metric.PAIRWISE_SUMMARIZATION_QUALITY,
91+
constants.Metric.PAIRWISE_QUESTION_ANSWERING_QUALITY,
92+
],
93+
baseline_model: Optional[
94+
Union[generative_models.GenerativeModel, Callable[[str], str]]
9195
] = None,
9296
use_reference: bool = False,
9397
version: Optional[int] = None,
@@ -97,10 +101,14 @@ def __init__(
97101
Args:
98102
metric: The Side-by-side(SxS) pairwise evaluation metric name.
99103
baseline_model: The baseline model for the Side-by-side(SxS) comparison.
104+
If not specified, `baseline_model_response` column is required in the dataset.
100105
use_reference: Whether to use reference to compute the metric. If
101106
specified, the reference column is required in the dataset.
102107
version: The metric version to use for evaluation.
103108
"""
109+
# TODO(b/311221071): Remove the legacy metric names for GA.
110+
if metric in ("summarization_quality", "question_answering_quality"):
111+
metric = f"pairwise_{metric}"
104112
self._metric = metric
105113
self._baseline_model = baseline_model
106114
self._use_reference = use_reference
@@ -111,7 +119,7 @@ def __str__(self):
111119

112120
@property
113121
def pairwise_metric_name(self) -> str:
114-
return f"pairwise_{self._metric}"
122+
return self._metric
115123

116124
@property
117125
def baseline_model(
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2024 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
from typing import Callable, Optional, Union
19+
from vertexai.generative_models import _generative_models
20+
from vertexai.preview.evaluation import constants
21+
from vertexai.preview.evaluation.metrics import _base
22+
23+
24+
class PairwiseQuestionAnsweringQuality(_base.PairwiseMetric):
25+
"""The Side-by-side(SxS) Pairwise Metric for Question Answering Quality."""
26+
27+
_metric_name = constants.Metric.PAIRWISE_QUESTION_ANSWERING_QUALITY
28+
29+
def __init__(
30+
self,
31+
*,
32+
baseline_model: Optional[
33+
Union[_generative_models.GenerativeModel, Callable[[str], str]]
34+
] = None,
35+
use_reference: bool = False,
36+
version: Optional[int] = None
37+
):
38+
super().__init__(
39+
metric=PairwiseQuestionAnsweringQuality._metric_name,
40+
baseline_model=baseline_model,
41+
use_reference=use_reference,
42+
version=version,
43+
)
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2024 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
from typing import Callable, Optional, Union
19+
from vertexai.generative_models import _generative_models
20+
from vertexai.preview.evaluation import constants
21+
from vertexai.preview.evaluation.metrics import _base
22+
23+
24+
class PairwiseSummarizationQuality(_base.PairwiseMetric):
25+
"""The Side-by-side(SxS) Pairwise Metric for summarization quality."""
26+
27+
_metric_name = constants.Metric.PAIRWISE_SUMMARIZATION_QUALITY
28+
29+
def __init__(
30+
self,
31+
*,
32+
baseline_model: Optional[
33+
Union[_generative_models.GenerativeModel, Callable[[str], str]]
34+
] = None,
35+
use_reference: bool = False,
36+
version: Optional[int] = None
37+
):
38+
super().__init__(
39+
metric=PairwiseSummarizationQuality._metric_name,
40+
baseline_model=baseline_model,
41+
use_reference=use_reference,
42+
version=version,
43+
)

0 commit comments

Comments
 (0)