Bootstrap RMSE Confidence Intervals in SQL (Presto/Athena) for multiple subgroups
In this blog post I want to show you how to calculate approximate confidence intervals in SQL (Presto/Athena dialect) using bootstrapping, specifically for the Root Mean Squared Error (RMSE) and in multiple subgroups.
This post builds on top of this blog post, Bootstrap Confidence Intervals in SQL for PostgreSQL and BigQuery, but extends it for multiple subgroups.
Final result
Starting from some data with columns (category, label, err), where err is the prediction error calculated as the correct value minus your predicted probability (err = correct - prediction):
category | label | err |
---|---|---|
0 | 0 | 1 |
0 | 0 | 2 |
0 | 0 | 3 |
0 | 1 | 11 |
0 | 1 | 12 |
0 | 1 | 13 |
0 | 1 | 14 |
1 | 0 | 100 |
1 | 0 | 200 |
1 | 0 | 300 |
1 | 0 | 400 |
1 | 0 | 500 |
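The query calculates the bootstrapped 95% RMSE CI for every category/label subgroup. The bounds change a little from run to run because the resampling is random: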
category | label | rmse_lo | rmse | rmse_hi |
---|---|---|---|---|
0 | 0 | 1.0 | 2.16 | 3.0 |
0 | 1 | 11.53 | 12.55 | 13.28 |
1 | 0 | 223.61 | 331.66 | 397.49 |
How
Let's jump right in!
with labelled as (
select *
from (
values (0, 0, 1),
(0, 0, 2),
(0, 0, 3),
(0, 1, 11),
(0, 1, 12),
(0, 1, 13),
(0, 1, 14),
(1, 0, 100),
(1, 0, 200),
(1, 0, 300),
(1, 0, 400),
(1, 0, 500)
) as t(category, label, err)
),
bootstrap_indexes as (
SELECT bootstrap_index
-- 10 bootstrap samples keep the demo small; raise the upper bound for real use
FROM (SELECT sequence(1, 10) as d)
CROSS JOIN UNNEST (d) as t(bootstrap_index)
),
bootstrap_data AS (
SELECT
category,
label,
err,
-- zero-based row number within each subgroup (no ORDER BY, so the order is arbitrary)
ROW_NUMBER() OVER(partition by category, label) - 1 AS data_index
FROM labelled
),
bootstrap_amounts as (
select
category,
label,
count(*) as amount
from bootstrap_data
group by category, label
),
bootstrap_map AS (
SELECT
d.category,
d.label,
a.amount,
-- uniform random integer in [0, amount): the data_index of the row to draw
cast(floor(random() * a.amount) as int) AS sampled_index,
bootstrap_index
FROM bootstrap_data d
JOIN bootstrap_amounts a
on a.category = d.category
and a.label = d.label
JOIN bootstrap_indexes ON TRUE
),
bootstrap AS (
SELECT
m.category,
m.label,
bootstrap_index,
err
FROM bootstrap_map m
JOIN bootstrap_data d
on m.category = d.category
and m.label = d.label
and m.sampled_index = d.data_index
),
bootstrap_aggregated AS (
SELECT
category,
label,
bootstrap_index,
sqrt(avg(pow(err, 2))) as rmse
FROM bootstrap
group by category, label, bootstrap_index
),
bootstrap_ci AS (
SELECT
category,
label,
approx_percentile(rmse, 0.025) as rmse_lo,
approx_percentile(rmse, 0.975) as rmse_hi
FROM bootstrap_aggregated
group by category, label
),
sample AS (
SELECT
category,
label,
sqrt(avg(pow(err, 2))) AS rmse_avg
FROM labelled
group by category, label
)
SELECT
s.category,
s.label,
round(rmse_lo,2) as rmse_lo,
round(rmse_avg, 2) as rmse,
round(rmse_hi, 2) as rmse_hi
FROM sample s
JOIN bootstrap_ci c
on s.category = c.category
and s.label = c.label
order by category, label
Let's break this query down one CTE at a time.
labelled creates our data.
category | label | err |
---|---|---|
0 | 0 | 1 |
0 | 0 | 2 |
0 | 0 | 3 |
… | .. | .. |
1 | 0 | 500 |
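bootstrap_indexes enumerates the bootstrap samples you want to draw. The example uses only 10 to keep the output small; for real intervals you would typically draw many more (say 1,000) so that the tail percentiles are stable.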
bootstrap_index |
---|
1 |
… |
9 |
10 |
bootstrap_data enumerates the rows within each category/label subgroup. We need this CTE because we are going to use data_index to resample.
category | label | err | data_index |
---|---|---|---|
0 | 1 | 11 | 0 |
0 | 1 | 12 | 1 |
0 | 1 | 13 | 2 |
0 | 1 | 14 | 3 |
0 | 0 | 1 | 0 |
0 | 0 | 2 | 1 |
0 | 0 | 3 | 2 |
1 | 0 | 100 | 0 |
1 | 0 | 200 | 1 |
1 | 0 | 300 | 2 |
1 | 0 | 400 | 3 |
1 | 0 | 500 | 4 |
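If you want to check this step outside the database, here is a minimal Python sketch of the same enumeration. The helper name enumerate_within_groups is my own, and it numbers rows in input order, whereas the SQL has no ORDER BY and may number them in any order.

```python
from collections import defaultdict

rows = [
    (0, 0, 1), (0, 0, 2), (0, 0, 3),
    (0, 1, 11), (0, 1, 12), (0, 1, 13), (0, 1, 14),
    (1, 0, 100), (1, 0, 200), (1, 0, 300), (1, 0, 400), (1, 0, 500),
]


def enumerate_within_groups(rows):
    """Mimic ROW_NUMBER() OVER (PARTITION BY category, label) - 1."""
    counters = defaultdict(int)  # next free index per (category, label)
    out = []
    for category, label, err in rows:
        key = (category, label)
        out.append((category, label, err, counters[key]))
        counters[key] += 1
    return out


bootstrap_data = enumerate_within_groups(rows)
for row in bootstrap_data:
    print(row)
```

Every subgroup gets its own zero-based index, which is what lets us later pick a row with a random integer below the subgroup size.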
bootstrap_amounts contains the number of rows in each category/label subgroup.
category | label | amount |
---|---|---|
0 | 1 | 4 |
0 | 0 | 3 |
1 | 0 | 5 |
bootstrap_map performs the resampling with replacement.
For each category/label subgroup we take a full sample by drawing random integers in the range [0, amount), where amount differs per subgroup.
The JOIN bootstrap_indexes ON TRUE gives the full Cartesian product of the bootstrap and data indexes. In this case we have 10 bootstrap samples and 12 observations, which results in 120 rows.
category | label | amount | sampled_index | bootstrap_index |
---|---|---|---|---|
0 | 0 | 3 | 0 | 1 |
0 | 0 | 3 | 2 | 1 |
0 | 0 | 3 | 1 | 1 |
0 | 1 | 4 | 3 | 1 |
.. | .. | .. | .. | .. |
0 | 1 | 4 | 1 | 9 |
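The index sampling can be sketched in Python as follows. This is only an illustration: the seed and the constant N_BOOTSTRAP are my own additions, and Python's generator will not reproduce the numbers Presto draws.

```python
import math
import random
from collections import Counter

rows = [
    (0, 0, 1), (0, 0, 2), (0, 0, 3),
    (0, 1, 11), (0, 1, 12), (0, 1, 13), (0, 1, 14),
    (1, 0, 100), (1, 0, 200), (1, 0, 300), (1, 0, 400), (1, 0, 500),
]
N_BOOTSTRAP = 10  # mirrors sequence(1, 10) in the query

random.seed(0)  # assumption: fixed seed, only to make the sketch repeatable

# Rows per (category, label) subgroup, the role of bootstrap_amounts.
amounts = Counter((category, label) for category, label, _ in rows)

# One random index per (original row, bootstrap sample) pair.
# random.random() is in [0, 1), so the floor lands in 0 .. amount - 1.
bootstrap_map = [
    (category, label, amounts[(category, label)],
     math.floor(random.random() * amounts[(category, label)]),
     bootstrap_index)
    for category, label, _ in rows
    for bootstrap_index in range(1, N_BOOTSTRAP + 1)
]

print(len(bootstrap_map))  # 12 rows x 10 samples = 120
```

Note that the err value of the original row is ignored here; the row only contributes one draw, so each subgroup is resampled to its original size.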
bootstrap joins back the actual data in place of the sampled index. This bootstrap table now contains the 10 full new samples we bootstrapped!
category | label | bootstrap_index | err |
---|---|---|---|
0 | 0 | 1 | 1 |
0 | 0 | 1 | 1 |
0 | 0 | 1 | 1 |
0 | 1 | 1 | 12 |
.. | .. | .. | .. |
0 | 1 | 9 | 13 |
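As a rough Python counterpart of this join, the sketch below looks up the err value for every sampled index. The names errs_by_group and N_BOOTSTRAP and the seed are assumptions of mine, not part of the query.

```python
import random
from collections import defaultdict

rows = [
    (0, 0, 1), (0, 0, 2), (0, 0, 3),
    (0, 1, 11), (0, 1, 12), (0, 1, 13), (0, 1, 14),
    (1, 0, 100), (1, 0, 200), (1, 0, 300), (1, 0, 400), (1, 0, 500),
]
N_BOOTSTRAP = 10

random.seed(0)  # assumption: fixed seed for a repeatable sketch

# List position plays the role of data_index within each subgroup.
errs_by_group = defaultdict(list)
for category, label, err in rows:
    errs_by_group[(category, label)].append(err)

bootstrap = []
for (category, label), errs in errs_by_group.items():
    for bootstrap_index in range(1, N_BOOTSTRAP + 1):
        for _ in errs:  # one draw per original row in the subgroup
            sampled_index = random.randrange(len(errs))
            # The "join": replace the sampled index by the err it points to.
            bootstrap.append((category, label, bootstrap_index, errs[sampled_index]))

print(len(bootstrap))  # still 120 rows, now carrying err values
```

Because data_index is unique within a subgroup, the join never multiplies rows: every sampled index matches exactly one observation.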
bootstrap_aggregated calculates the RMSE for each individual sample (i.e. each bootstrap_index).
category | label | bootstrap_index | rmse |
---|---|---|---|
0 | 0 | 2 | 3.0 |
0 | 0 | 4 | 2.71 |
0 | 0 | 5 | 2.71 |
0 | 0 | 9 | 1.41 |
0 | 0 | 1 | 2.16 |
0 | 0 | 6 | 2.52 |
0 | 0 | 8 | 2.16 |
0 | 0 | 10 | 2.16 |
1 | 0 | 4 | 428.95 |
0 | 1 | 2 | 13.04 |
.. | .. | .. | .. |
1 | 0 | 10 | 240.83 |
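To get a feel for these numbers, here is a small Python sketch that computes the RMSE of a few possible resamples of the (0, 0) subgroup, whose errors are 1, 2 and 3. The resamples listed are hypothetical; I don't know which draws actually produced the rows above, only that these combinations give the same values.

```python
import math


def rmse(errs):
    """Root mean squared error, same as sqrt(avg(pow(err, 2))) in the query."""
    return math.sqrt(sum(e ** 2 for e in errs) / len(errs))


# Hypothetical resamples (with replacement) of the errors 1, 2, 3.
examples = [(3, 3, 3), (3, 3, 2), (1, 1, 2), (1, 2, 3), (3, 3, 1)]

for resample in examples:
    print(resample, round(rmse(resample), 2))
```

These give 3.0, 2.71, 1.41, 2.16 and 2.52. With only three observations there are few distinct resamples, which is why the same RMSE values keep showing up in the table.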
bootstrap_ci uses approx_percentile() to find the percentiles of the empirical bootstrap distribution that we need.
category | label | rmse_lo | rmse_hi |
---|---|---|---|
0 | 0 | 1.41 | 2.52 |
1 | 0 | 249.0 | 500.0 |
0 | 1 | 11.51 | 13.53 |
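One thing worth seeing for yourself is how coarse these percentiles are with only 10 bootstrap samples. The Python sketch below uses a plain nearest-rank percentile as a stand-in; approx_percentile() uses its own approximation, so treat this as an illustration of the idea rather than of Presto's exact behaviour. The list of RMSE values is made up.

```python
import math


def nearest_rank_percentile(values, p):
    """Smallest value such that at least a fraction p of the data is <= it."""
    ordered = sorted(values)
    rank = max(1, math.ceil(p * len(ordered)))
    return ordered[rank - 1]


# Hypothetical RMSEs of 10 bootstrap samples for one subgroup.
replicate_rmses = [3.0, 2.71, 2.71, 1.41, 2.16, 2.52, 2.16, 2.16, 2.52, 1.0]

rmse_lo = nearest_rank_percentile(replicate_rmses, 0.025)
rmse_hi = nearest_rank_percentile(replicate_rmses, 0.975)
print(rmse_lo, rmse_hi)
```

With 10 values, the nearest-rank 2.5th and 97.5th percentiles are simply the minimum and the maximum, so the interval is driven by the two most extreme samples. Drawing more bootstrap samples makes the bounds much less jumpy.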
sample computes the actual observed RMSE in the category/label subgroups.
Note that I calculate the RMSE from the actual observations here and not from the bootstrapped samples!
category | label | rmse_avg |
---|---|---|
0 | 1 | 12.55 |
0 | 0 | 2.16 |
1 | 0 | 331.66 |
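The observed values are deterministic, so they are easy to double-check by hand or with a few lines of Python. The sketch below is just such a cross-check; the variable names are my own.

```python
import math
from collections import defaultdict

rows = [
    (0, 0, 1), (0, 0, 2), (0, 0, 3),
    (0, 1, 11), (0, 1, 12), (0, 1, 13), (0, 1, 14),
    (1, 0, 100), (1, 0, 200), (1, 0, 300), (1, 0, 400), (1, 0, 500),
]

# Collect the errors of every (category, label) subgroup.
errs_by_group = defaultdict(list)
for category, label, err in rows:
    errs_by_group[(category, label)].append(err)

# RMSE of the original observations, rounded like the final SELECT.
observed = {
    key: round(math.sqrt(sum(e ** 2 for e in errs) / len(errs)), 2)
    for key, errs in errs_by_group.items()
}

for key in sorted(observed):
    print(key, observed[key])
```

For example, the (0, 0) subgroup gives sqrt((1 + 4 + 9) / 3), which rounds to 2.16, matching the table.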
Finally, we can combine everything together: the bootstrapped RMSE 95% CI and the observed RMSE for every category/label subgroup.
category | label | rmse_lo | rmse | rmse_hi |
---|---|---|---|---|
0 | 0 | 1.41 | 2.16 | 2.71 |
0 | 1 | 11.51 | 12.55 | 13.51 |
1 | 0 | 249.0 | 331.66 | 428.95 |
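If you want an independent cross-check of the whole pipeline, the Python sketch below does the same thing in one function. It is not a translation of the SQL: the function name, the seed, the default of 1,000 bootstrap samples and the nearest-rank percentile are all choices of mine, so expect bounds in the same ballpark as the query, not identical numbers.

```python
import math
import random
from collections import defaultdict

rows = [
    (0, 0, 1), (0, 0, 2), (0, 0, 3),
    (0, 1, 11), (0, 1, 12), (0, 1, 13), (0, 1, 14),
    (1, 0, 100), (1, 0, 200), (1, 0, 300), (1, 0, 400), (1, 0, 500),
]


def rmse(errs):
    return math.sqrt(sum(e * e for e in errs) / len(errs))


def nearest_rank(values, p):
    """Simple nearest-rank percentile (stand-in for approx_percentile)."""
    ordered = sorted(values)
    return ordered[max(1, math.ceil(p * len(ordered))) - 1]


def bootstrap_rmse_ci(rows, n_bootstrap=1000, seed=42):
    """Return {(category, label): (rmse_lo, rmse, rmse_hi)}."""
    rng = random.Random(seed)  # assumption: seeded for repeatability
    groups = defaultdict(list)
    for category, label, err in rows:
        groups[(category, label)].append(err)

    result = {}
    for key, errs in sorted(groups.items()):
        # Resample each subgroup to its own size, with replacement.
        replicates = [
            rmse(rng.choices(errs, k=len(errs))) for _ in range(n_bootstrap)
        ]
        result[key] = (
            nearest_rank(replicates, 0.025),
            rmse(errs),  # observed RMSE, not a bootstrapped one
            nearest_rank(replicates, 0.975),
        )
    return result


ci = bootstrap_rmse_ci(rows)
for key, (lo, observed, hi) in ci.items():
    print(key, round(lo, 2), round(observed, 2), round(hi, 2))
```

A useful sanity check on any output, SQL or Python: a bootstrapped RMSE can never fall below the smallest or rise above the largest absolute error in its subgroup.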
Optional: Simple case (without subgroups)
For the simple case (without subgroups) I want to refer you back to the original blog post.
For completeness I'll add it in. Imagine you have a table with prediction errors like this:
err |
---|
100 |
200 |
300 |
400 |
500 |
And you want to turn it into this (where rmse_lo and rmse_hi are the bounds of the bootstrapped 95% RMSE CI):
rmse_lo | rmse | rmse_hi |
---|---|---|
279.28 | 331.66 | 426.61 |
with labelled as (
select *
from (
values
(100),
(200),
(300),
(400),
(500)
) as t(err)
),
bootstrap_indexes as (
SELECT bootstrap_index
FROM (SELECT sequence(1, 10) as d)
CROSS JOIN UNNEST (d) as t(bootstrap_index)
),
bootstrap_data AS (
SELECT
err,
ROW_NUMBER() OVER() - 1 AS data_index
FROM labelled
),
bootstrap_map AS (
SELECT
cast(floor(random() * (select count(data_index) from bootstrap_data)) as int) AS sampled_index,
bootstrap_index
FROM bootstrap_data d
JOIN bootstrap_indexes ON TRUE
),
bootstrap AS (
SELECT
bootstrap_index,
err
FROM bootstrap_map m
JOIN bootstrap_data d on m.sampled_index = d.data_index
),
bootstrap_aggregated AS (
SELECT
bootstrap_index,
sqrt(avg(pow(err, 2))) as rmse
FROM bootstrap
group by bootstrap_index
),
bootstrap_ci AS (
SELECT
approx_percentile(rmse, 0.025) as rmse_lo,
approx_percentile(rmse, 0.975) as rmse_hi
FROM bootstrap_aggregated
),
sample AS (
SELECT
sqrt(avg(pow(err, 2))) AS rmse_avg
FROM labelled
)
SELECT
round(rmse_lo,2) as rmse_lo,
round(rmse_avg, 2) as rmse,
round(rmse_hi, 2) as rmse_hi
FROM sample s
JOIN bootstrap_ci on TRUE