
Commit 1ab4344

jsondai authored and copybara-github committed
feat: Add notebook helper functions to preview eval SDK to display and visualize evaluation results in an IPython environment
PiperOrigin-RevId: 725404155
1 parent: 0abe0b7

3 files changed: 259 additions and 0 deletions

vertexai/evaluation/eval_task.py

Lines changed: 4 additions & 0 deletions
@@ -14,8 +14,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import logging
 from typing import Any, Callable, Dict, List, Literal, Optional, TYPE_CHECKING, Union
 import uuid
+import warnings

 from google.api_core import exceptions
 import vertexai
@@ -47,6 +49,8 @@
 IPython_display = None

 _LOGGER = base.Logger(__name__)
+logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR)
+warnings.filterwarnings("ignore")

 EvalResult = eval_base.EvalResult
 GenerativeModel = generative_models.GenerativeModel

vertexai/preview/evaluation/eval_task.py

Lines changed: 4 additions & 0 deletions
@@ -16,6 +16,8 @@
 #
 """Evaluation Task class."""

+import logging
+import warnings
 from typing import Any, Callable, Dict, List, Literal, Optional, TYPE_CHECKING, Union
 import uuid

@@ -48,6 +50,8 @@
 IPython_display = None

 _LOGGER = base.Logger(__name__)
+logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR)
+warnings.filterwarnings("ignore")

 AutoraterConfig = eval_base.AutoraterConfig
 EvalResult = eval_base.EvalResult
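
Both eval_task.py modules receive the same two module-level statements. As a minimal sketch of their effect (a hypothetical standalone snippet, not taken from the SDK): once they run, urllib3.connectionpool log records below ERROR and all Python warnings are suppressed for the process, which keeps notebook output free of retry and deprecation noise while results are displayed.

# Hypothetical standalone repro of the module-level side effect; assumes a
# fresh interpreter and is not part of the commit itself.
import logging
import warnings

logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR)
warnings.filterwarnings("ignore")

warnings.warn("deprecation notice")  # suppressed by the warnings filter
logging.getLogger("urllib3.connectionpool").warning("retrying request")  # below ERROR, not emitted
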
New file · Lines changed: 251 additions & 0 deletions
@@ -0,0 +1,251 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Python functions which run only within a Jupyter or Colab notebook."""
+
+import random
+import string
+import sys
+from typing import List, Optional, Tuple
+
+from vertexai.preview.evaluation import _base as eval_base
+from vertexai.preview.evaluation import constants
+
+# pylint: disable=g-import-not-at-top
+try:
+    import pandas as pd
+except ImportError:
+    pandas = None
+
+_MARKDOWN_H2 = "##"
+_MARKDOWN_H3 = "###"
+_DEFAULT_COLUMNS_TO_DISPLAY = [
+    constants.Dataset.MODEL_RESPONSE_COLUMN,
+    constants.Dataset.BASELINE_MODEL_RESPONSE_COLUMN,
+    constants.Dataset.PROMPT_COLUMN,
+    constants.MetricResult.ROW_COUNT_KEY,
+]
+_DEFAULT_RADAR_RANGE = (0, 5)
+
+
+def _get_ipython_shell_name() -> str:
+    if "IPython" in sys.modules:
+        # pylint: disable=g-import-not-at-top, g-importing-member
+        from IPython import get_ipython
+
+        return get_ipython().__class__.__name__
+    return ""
+
+
+def is_ipython_available() -> bool:
+    return _get_ipython_shell_name()
+
+
+def _filter_df(
+    df: pd.DataFrame, substrings: Optional[List[str]] = None
+) -> pd.DataFrame:
+    """Filters a DataFrame to include only columns containing the given substrings."""
+    if substrings is None:
+        return df
+
+    return df.copy().filter(
+        [
+            column_name
+            for column_name in df.columns
+            if any(substring in column_name for substring in substrings)
+        ]
+    )
+
+
+def display_eval_result(
+    eval_result: "eval_base.EvalResult",
+    title: Optional[str] = None,
+    metrics: Optional[List[str]] = None,
+) -> None:
+    """Displays evaluation results in a notebook using IPython.display.
+
+    Args:
+      eval_result: An object containing evaluation results with
+        `summary_metrics` and `metrics_table` attributes.
+      title: A string title to display above the results.
+      metrics: A list of metric name substrings to filter displayed columns. If
+        provided, only metrics whose names contain any of these strings will be
+        displayed.
+    """
+    if not is_ipython_available():
+        return
+    # pylint: disable=g-import-not-at-top, g-importing-member
+    from IPython.display import display
+    from IPython.display import Markdown
+
+    summary_metrics, metrics_table = (
+        eval_result.summary_metrics,
+        eval_result.metrics_table,
+    )
+
+    summary_metrics_df = pd.DataFrame.from_dict(summary_metrics, orient="index").T
+
+    if metrics:
+        columns_to_keep = metrics + _DEFAULT_COLUMNS_TO_DISPLAY
+        summary_metrics_df = _filter_df(summary_metrics_df, columns_to_keep)
+        metrics_table = _filter_df(metrics_table, columns_to_keep)
+
+    # Display the title in Markdown.
+    if title:
+        display(Markdown(f"{_MARKDOWN_H2} {title}"))
+
+    # Display the summary metrics.
+    display(Markdown(f"{_MARKDOWN_H3} Summary Metrics"))
+    display(summary_metrics_df)
+
+    # Display the metrics table.
+    display(Markdown(f"{_MARKDOWN_H3} Row-based Metrics"))
+    display(metrics_table)
+
+
+def display_explanations(
+    eval_result: "eval_base.EvalResult",
+    num: int = 1,
+    metrics: Optional[List[str]] = None,
+) -> None:
+    """Displays the explanations in a notebook using IPython.display.
+
+    Args:
+      eval_result: An object containing evaluation results. It is expected to
+        have attributes `summary_metrics` and `metrics_table`.
+      num: The number of row samples to display. Defaults to 1. If the number of
+        rows is less than `num`, all rows will be displayed.
+      metrics: A list of metric name substrings to filter displayed columns. If
+        provided, only metrics whose names contain any of these strings will be
+        displayed.
+    """
+    if not is_ipython_available():
+        return
+    # pylint: disable=g-import-not-at-top, g-importing-member
+    from IPython.display import display
+    from IPython.display import HTML
+
+    style = "white-space: pre-wrap; width: 1500px; overflow-x: auto;"
+    metrics_table = eval_result.metrics_table
+
+    if num < 1:
+        raise ValueError("Num must be greater than 0.")
+    num = min(num, len(metrics_table))
+
+    df = metrics_table.sample(n=num)
+
+    if metrics:
+        columns_to_keep = metrics + _DEFAULT_COLUMNS_TO_DISPLAY
+        df = _filter_df(df, columns_to_keep)
+
+    for _, row in df.iterrows():
+        for col in df.columns:
+            display(HTML(f"<div style='{style}'><h4>{col}:</h4>{row[col]}</div>"))
+        display(HTML("<hr>"))
+
+
+def display_radar_plot(
+    eval_results_with_title: List[Tuple[str, "eval_base.EvalResult"]],
+    metrics: List[str],
+    radar_range: Tuple[float, float] = _DEFAULT_RADAR_RANGE,
+) -> None:
+    """Plots a radar plot comparing evaluation results.
+
+    Args:
+      eval_results_with_title: List of (title, eval_result) tuples.
+      metrics: A list of metrics whose mean values will be plotted.
+      radar_range: Range of the radar plot axes.
+    """
+    # pylint: disable=g-import-not-at-top
+    try:
+        import plotly.graph_objects as go
+    except ImportError as exc:
+        raise ImportError(
+            '`plotly` is not installed. Please install using "!pip install plotly"'
+        ) from exc
+
+    fig = go.Figure()
+    for title, eval_result in eval_results_with_title:
+        summary_metrics = eval_result.summary_metrics
+        if metrics:
+            summary_metrics = {
+                key.replace("/mean", ""): summary_metrics[key]
+                for key in summary_metrics
+                if any(selected_metric + "/mean" in key for selected_metric in metrics)
+            }
+        fig.add_trace(
+            go.Scatterpolar(
+                r=list(summary_metrics.values()),
+                theta=list(summary_metrics.keys()),
+                fill="toself",
+                name=title,
+            )
+        )
+    fig.update_layout(
+        polar=dict(radialaxis=dict(visible=True, range=radar_range)),
+        showlegend=True,
+    )
+    fig.show()
+
+
+def display_bar_plot(
+    eval_results_with_title: List[Tuple[str, "eval_base.EvalResult"]],
+    metrics: List[str],
+) -> None:
+    """Plots a bar plot comparing evaluation results.
+
+    Args:
+      eval_results_with_title: List of (title, eval_result) tuples.
+      metrics: A list of metrics whose mean values will be plotted.
+    """
+
+    # pylint: disable=g-import-not-at-top
+    try:
+        import plotly.graph_objects as go
+    except ImportError as exc:
+        raise ImportError(
+            '`plotly` is not installed. Please install using "!pip install plotly"'
+        ) from exc
+
+    data = []
+
+    for title, eval_result in eval_results_with_title:
+        summary_metrics = eval_result.summary_metrics
+        mean_summary_metrics = [f"{metric}/mean" for metric in metrics]
+        updated_summary_metrics = []
+        if metrics:
+            for k, v in summary_metrics.items():
+                if k in mean_summary_metrics:
+                    updated_summary_metrics.append((k, v))
+            summary_metrics = dict(updated_summary_metrics)
+
+        data.append(
+            go.Bar(
+                x=list(summary_metrics.keys()),
+                y=list(summary_metrics.values()),
+                name=title,
+            )
+        )
+
+    fig = go.Figure(data=data)
+
+    fig.update_layout(barmode="group", showlegend=True)
+    fig.show()
+
+
+def generate_uuid(length: int = 8) -> str:
+    """Generates a uuid of a specified length (default=8)."""
+    return "".join(random.choices(string.ascii_lowercase + string.digits, k=length))
