Skip to content

Commit e7b83f1

Browse files
rey-esptswast
andauthored
docs: add snippet for predicting classifications using a boosted tree model (#1156)
* docs: add snippet for predicting classifications using a boosted tree model * merge and rename bigquery_dataframes_bqml_boosted_tree_explain to bigquery_dataframes_bqml_boosted_tree_evaluate * remove training | * clean up asserts --------- Co-authored-by: Tim Sweña (Swast) <[email protected]>
1 parent 9d8970a commit e7b83f1

File tree

1 file changed

+28
-3
lines changed

1 file changed

+28
-3
lines changed

samples/snippets/classification_boosted_tree_model_test.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def test_boosted_tree_model(random_model_id: str) -> None:
6262
replace=True,
6363
)
6464
# [END bigquery_dataframes_bqml_boosted_tree_create]
65-
# [START bigquery_dataframes_bqml_boosted_tree_explain]
65+
# [START bigquery_dataframes_bqml_boosted_tree_evaluate]
6666
# Select model you'll use for predictions. `read_gbq_model` loads model
6767
# data from BigQuery, but you could also use the `tree_model` object
6868
# from the previous step.
@@ -82,8 +82,33 @@ def test_boosted_tree_model(random_model_id: str) -> None:
8282
# Output:
8383
# precision recall accuracy f1_score log_loss roc_auc
8484
# 0 0.671924 0.578804 0.839429 0.621897 0.344054 0.887335
85-
# [END bigquery_dataframes_bqml_boosted_tree_explain]
85+
# [END bigquery_dataframes_bqml_boosted_tree_evaluate]
86+
# [START bigquery_dataframes_bqml_boosted_tree_predict]
87+
# Select model you'll use for predictions. `read_gbq_model` loads model
88+
# data from BigQuery, but you could also use the `tree_model` object
89+
# from previous steps.
90+
tree_model = bpd.read_gbq_model(
91+
your_model_id, # For example: "your-project.bqml_tutorial.tree_model"
92+
)
93+
94+
# input_data is defined in an earlier step.
95+
prediction_data = input_data[input_data["dataframe"] == "prediction"]
96+
97+
predictions = tree_model.predict(prediction_data)
98+
predictions.peek()
99+
# Output:
100+
# predicted_income_bracket predicted_income_bracket_probs.label predicted_income_bracket_probs.prob
101+
# <=50K >50K 0.05183430016040802
102+
# <50K 0.94816571474075317
103+
# <=50K >50K 0.00365859130397439
104+
# <50K 0.99634140729904175
105+
# <=50K >50K 0.037775970995426178
106+
# <50K 0.96222406625747681
107+
# [END bigquery_dataframes_bqml_boosted_tree_predict]
108+
assert input_data is not None
109+
assert training_data is not None
86110
assert tree_model is not None
87111
assert evaluation_data is not None
88112
assert score is not None
89-
assert input_data is not None
113+
assert prediction_data is not None
114+
assert predictions is not None

0 commit comments

Comments
 (0)