# Source code for incerto.conformal.methods
"""
incerto.conformal.methods
-------------------------
Stateless helper functions that wrap a *trained* base model and produce
prediction sets or intervals at a user-specified mis-coverage rate α.
Each method adds only what is necessary for conformal inference—no
optimisers, schedulers, or training loops are defined here.
"""
from __future__ import annotations
from typing import Callable, Tuple, List, Union
import math
import torch
def _validate_alpha(alpha: float) -> None:
"""Validate that alpha is in (0, 1)."""
if not (0.0 < alpha < 1.0):
raise ValueError(f"alpha must be in (0, 1), got {alpha}")
def _conformal_quantile(scores: torch.Tensor, alpha: float) -> torch.Tensor:
"""Compute the conformal quantile with exact finite-sample correction.
Returns the ⌈(1 − α)(n + 1)⌉-th smallest score, which guarantees
1 − α coverage under exchangeability.
"""
n = len(scores)
sorted_scores = torch.sort(scores)[0]
k = math.ceil((1 - alpha) * (n + 1)) # 1-indexed
k = max(1, min(k, n))
return sorted_scores[k - 1]
class ConformalPredictor:
    """Object-oriented facade over a calibrated conformal predictor.

    Wraps the closure returned by one of the conformal method functions so
    that classification (prediction sets) and regression (intervals) share a
    single, consistent interface.

    Args:
        predictor: Callable returned by a conformal method function.
        method: Name of the conformal method used.
        alpha: Miscoverage rate used during calibration.

    Example:
        >>> cp = ConformalPredictor.from_method(
        ...     "raps", model=model, calib_loader=loader, alpha=0.1
        ... )
        >>> pred_sets = cp.predict(x_test)
    """

    def __init__(
        self,
        predictor: Callable,
        method: str = "unknown",
        alpha: float = 0.0,
    ):
        self._predictor = predictor
        self.method = method
        self.alpha = alpha

    def predict(
        self, x: torch.Tensor
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        """Apply the calibrated predictor to new inputs.

        Classification methods yield ``List[torch.Tensor]`` prediction sets;
        regression methods yield a ``(lower, upper)`` pair of interval bounds.
        """
        return self._predictor(x)

    def __call__(
        self, x: torch.Tensor
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        # Calling the wrapper is sugar for .predict().
        return self.predict(x)

    def __repr__(self) -> str:
        return f"ConformalPredictor(method='{self.method}', alpha={self.alpha})"

    @classmethod
    def from_method(cls, method: str, **kwargs) -> "ConformalPredictor":
        """Factory: run the named conformal method and wrap its predictor.

        Args:
            method: One of ``'inductive_conformal'``, ``'mondrian_conformal'``,
                ``'aps'``, ``'raps'``, ``'jackknife_plus'``, ``'cv_plus'``,
                ``'conformalized_quantile_regression'``.
            **kwargs: Arguments forwarded to the method function.

        Returns:
            A :class:`ConformalPredictor` wrapping the calibrated predictor.
        """
        methods = {
            "inductive_conformal": inductive_conformal,
            "mondrian_conformal": mondrian_conformal,
            "aps": aps,
            "raps": raps,
            "jackknife_plus": jackknife_plus,
            "cv_plus": cv_plus,
            "conformalized_quantile_regression": conformalized_quantile_regression,
        }
        if method not in methods:
            raise ValueError(f"Unknown method '{method}'. Choose from: {list(methods)}")
        # Record alpha for introspection; calibration itself validates it.
        alpha = kwargs.get("alpha", 0.0)
        predictor = methods[method](**kwargs)
        return cls(predictor, method=method, alpha=alpha)
@torch.no_grad()
def inductive_conformal(
    model: torch.nn.Module,
    calib_loader: torch.utils.data.DataLoader,
    alpha: float,
) -> Callable[[torch.Tensor], List[torch.Tensor]]:
    """
    Classical Inductive Conformal Prediction (ICP)
    — Vovk, Gammerman, and Shafer, *Algorithmic Learning in a Random World* (2005).

    Calibrates on held-out data and returns a predictor f̂(x) that emits a
    label set per input for classification.
    """
    _validate_alpha(alpha)
    model.eval()

    # Conformity score per calibration point: 1 − softmax prob of the true class.
    batch_scores = []
    for xb, yb in calib_loader:
        probs = torch.softmax(model(xb), dim=-1)
        batch_scores.append(1.0 - probs[torch.arange(len(yb)), yb])
    qhat = _conformal_quantile(torch.cat(batch_scores), alpha)

    def predictor(x: torch.Tensor) -> List[torch.Tensor]:
        model.eval()
        with torch.no_grad():
            probs = torch.softmax(model(x), dim=-1)
        # Keep every class whose probability clears the calibrated threshold.
        threshold = 1.0 - qhat
        return [(row >= threshold).nonzero().squeeze(-1) for row in probs]

    return predictor
@torch.no_grad()
def mondrian_conformal(
    model: torch.nn.Module,
    calib_loader: torch.utils.data.DataLoader,
    alpha: float,
    partition_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
) -> Callable[[torch.Tensor], List[torch.Tensor]]:
    """
    Mondrian Conformal Prediction
    — Papadopoulos, *Reliable Classification with Conformal Predictors* (2008).

    Calibrates a separate quantile per partition cell so that coverage holds
    *conditionally* within each cell, not just marginally.
    """
    _validate_alpha(alpha)
    if partition_fn is None:
        # Default Mondrian taxonomy: one cell per true class label.
        def partition_fn(x, y):
            return y

    model.eval()
    cell_scores = {}
    for xb, yb in calib_loader:
        probs = torch.softmax(model(xb), dim=-1)
        cells = partition_fn(xb, yb)
        for cell, row, label in zip(cells, probs, yb):
            cell_scores.setdefault(int(cell), []).append(1.0 - row[label])
    qhats = {
        cell: _conformal_quantile(torch.stack(vals), alpha)
        for cell, vals in cell_scores.items()
    }
    # Conservative fallback for cells never seen during calibration:
    # qhat = 1.0 gives threshold 1 − 1 = 0, i.e. the full label set.
    _default_qhat = torch.tensor(1.0)

    def predictor(x: torch.Tensor) -> List[torch.Tensor]:
        model.eval()
        with torch.no_grad():
            probs = torch.softmax(model(x), dim=-1)
            # True labels are unavailable at test time, so the cell is
            # derived from the model's argmax prediction instead.
            cells = partition_fn(x, probs.argmax(-1))
        sets = []
        for row, cell in zip(probs, cells):
            q = qhats.get(int(cell), _default_qhat)
            sets.append((row >= 1.0 - q).nonzero().squeeze(-1))
        return sets

    return predictor
@torch.no_grad()
def aps(
    model: torch.nn.Module,
    calib_loader: torch.utils.data.DataLoader,
    alpha: float,
) -> Callable[[torch.Tensor], List[torch.Tensor]]:
    """
    Adaptive Prediction Sets (APS)
    — Romano, Patterson, and Candes, *NeurIPS 2020*.

    Thresholds randomized cumulative probability mass so that set sizes adapt
    to example difficulty; the randomized scores make the coverage guarantee
    tight (non-conservative).
    """
    _validate_alpha(alpha)
    model.eval()

    score_chunks = []
    for xb, yb in calib_loader:
        probs = torch.softmax(model(xb), dim=-1)
        probs_desc, order = torch.sort(probs, descending=True, dim=-1)
        cum = probs_desc.cumsum(dim=-1)
        # Position of the true label within the descending sort.
        true_rank = (order == yb.unsqueeze(-1)).nonzero(as_tuple=True)[1]
        # Probability mass accumulated strictly before the true label.
        cum_before = torch.zeros_like(cum)
        cum_before[:, 1:] = cum[:, :-1]
        rows = torch.arange(len(yb))
        u = torch.rand(len(yb), device=probs.device)
        # Randomized score: mass before the true label + U · p_true.
        score_chunks.append(
            cum_before[rows, true_rank] + u * probs_desc[rows, true_rank]
        )
    qhat = _conformal_quantile(torch.cat(score_chunks), alpha)

    def predictor(x: torch.Tensor) -> List[torch.Tensor]:
        model.eval()
        with torch.no_grad():
            probs_desc, order = torch.sort(
                torch.softmax(model(x), dim=-1), descending=True, dim=-1
            )
            cum = probs_desc.cumsum(dim=-1)
            cum_before = torch.zeros_like(cum)
            cum_before[:, 1:] = cum[:, :-1]
            # One uniform draw per example, mirroring the calibration score.
            v = torch.rand(x.size(0), 1, device=probs_desc.device)
            randomized = cum_before + v * probs_desc
        return [order_i[rand_i <= qhat] for order_i, rand_i in zip(order, randomized)]

    return predictor
@torch.no_grad()
def raps(
    model: torch.nn.Module,
    calib_loader: torch.utils.data.DataLoader,
    alpha: float,
    lam: float = 0.0,
    k_reg: int = 1,
) -> Callable[[torch.Tensor], List[torch.Tensor]]:
    """
    Regularized APS (RAPS)
    — Angelopoulos, Bates, Malik, and Jordan, *ICLR 2021*.

    Adds an ℓ₁ penalty of weight ``lam`` to every class ranked beyond
    ``k_reg`` (discouraging large sets) and enforces a minimum set size of
    ``k_reg`` at prediction time.
    """
    _validate_alpha(alpha)
    model.eval()

    # Calibration scores g(y): cumulative probability through the true label
    # plus the rank penalty, per the RAPS definition.
    score_chunks = []
    for xb, yb in calib_loader:
        probs_desc, order = torch.sort(torch.softmax(model(xb), dim=-1), descending=True)
        true_rank = (order == yb[:, None]).nonzero()[:, 1]
        positions = torch.arange(1, probs_desc.size(-1) + 1, device=probs_desc.device)
        reg = lam * torch.clamp(positions - k_reg, min=0)
        g = probs_desc.cumsum(dim=-1) + reg
        score_chunks.append(g[torch.arange(len(yb)), true_rank])
    qhat = _conformal_quantile(torch.cat(score_chunks), alpha)

    def predictor(x: torch.Tensor) -> List[torch.Tensor]:
        model.eval()
        with torch.no_grad():
            probs_desc, order = torch.sort(torch.softmax(model(x), dim=-1), descending=True)
            positions = torch.arange(1, probs_desc.size(-1) + 1, device=probs_desc.device)
            reg = lam * torch.clamp(positions - k_reg, min=0)
            g = probs_desc.cumsum(dim=-1) + reg
            keep = (g <= qhat).long()
            # Always include the k_reg highest-probability classes.
            forced = torch.arange(probs_desc.size(-1), device=probs_desc.device)[None] < k_reg
            keep = keep | forced
        return [
            (order_i[keep_i == 1]).clone().detach()
            for order_i, keep_i in zip(order, keep)
        ]

    return predictor
# -------- Regression flavours (absolute residual conformity) -------- #
def jackknife_plus(
    model_fn: Callable[[torch.utils.data.Dataset], torch.nn.Module],
    train_dataset: torch.utils.data.Dataset,
    alpha: float,
) -> Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
    """
    Jackknife+ Intervals
    — Barber, Candès, and Ramdas, *Ann. Stat.* 2021.

    Trains one model per leave-one-out split, keeps all n models, and
    combines their predictions with the LOO residuals via the Jackknife+
    aggregation rule to form valid intervals.

    Args:
        model_fn: Function that trains a model on a provided dataset subset.
        train_dataset: Training dataset used for LOO calibration.
        alpha: Miscoverage rate (e.g., 0.1 for 90% coverage).
    """
    _validate_alpha(alpha)
    n = len(train_dataset)
    loo_models: List[torch.nn.Module] = []
    residuals = torch.empty(n)
    for held_out in range(n):
        keep = [j for j in range(n) if j != held_out]
        model = model_fn(torch.utils.data.Subset(train_dataset, keep))
        model.eval()
        loo_models.append(model)
        x_i, y_i = train_dataset[held_out]
        with torch.no_grad():
            y_hat = model(x_i.unsqueeze(0)).squeeze().cpu()
        residuals[held_out] = torch.abs(y_hat - y_i)

    def predictor(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        preds = []
        for m in loo_models:
            m.eval()
            with torch.no_grad():
                preds.append(m(x).squeeze())
        preds = torch.stack(preds, dim=0)  # (n, *test_shape)
        res = residuals if preds.dim() == 1 else residuals.unsqueeze(-1)
        lo_sorted = torch.sort(preds - res, dim=0)[0]
        hi_sorted = torch.sort(preds + res, dim=0)[0]
        # Jackknife+ order statistics (Barber et al., 2021, Theorem 1):
        # lower is the ⌈α(n+1)⌉-th smallest of {-∞} ∪ {μ̂₋ᵢ(x) − Rᵢ};
        # upper is the ⌈(1−α)(n+1)⌉-th smallest of {μ̂₋ᵢ(x) + Rᵢ} ∪ {+∞}.
        i_lo = math.ceil(alpha * (n + 1)) - 2  # 0-indexed into the finite values
        i_hi = math.ceil((1 - alpha) * (n + 1)) - 1  # 0-indexed
        if i_lo < 0:
            lower = torch.full_like(lo_sorted[0], float("-inf"))
        else:
            lower = lo_sorted[i_lo]
        if i_hi >= n:
            upper = torch.full_like(hi_sorted[0], float("inf"))
        else:
            upper = hi_sorted[i_hi]
        return lower, upper

    return predictor
def cv_plus(
    model_fn: Callable[[torch.utils.data.Dataset], torch.nn.Module],
    train_dataset: torch.utils.data.Dataset,
    folds: int,
    alpha: float,
) -> Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
    """
    Cross-Validation+ Intervals (CV+)
    — Barber et al., *Ann. Stat.* 2021.

    K-fold analogue of Jackknife+: each sample's residual comes from the
    fold model that did not see it, and all K fold models are retained and
    combined at prediction time via the CV+ aggregation rule.

    Args:
        model_fn: Function that trains a model on a provided dataset subset.
        train_dataset: Training dataset used for k-fold calibration.
        folds: Number of cross-validation folds.
        alpha: Miscoverage rate (e.g., 0.1 for 90% coverage).
    """
    _validate_alpha(alpha)
    n = len(train_dataset)
    perm = torch.randperm(n)
    # Fold sizes differ by at most one and sum exactly to n.
    sizes = [(n + i) // folds for i in range(folds)]
    fold_models: List[torch.nn.Module] = []
    residuals = torch.empty(n)
    sample_to_fold = torch.empty(n, dtype=torch.long)

    start = 0
    for k in range(folds):
        val_idx = perm[start : start + sizes[k]]
        start += sizes[k]
        held_out = set(val_idx.tolist())
        fit_idx = [i for i in perm.tolist() if i not in held_out]
        model = model_fn(torch.utils.data.Subset(train_dataset, fit_idx))
        model.eval()
        fold_models.append(model)
        with torch.no_grad():
            X_val = torch.stack([train_dataset[i][0] for i in val_idx])
            y_val = torch.tensor([train_dataset[i][1] for i in val_idx])
            res_k = torch.abs(y_val - model(X_val).squeeze())
        for pos, sample in enumerate(val_idx.tolist()):
            residuals[sample] = res_k[pos]
            sample_to_fold[sample] = k

    def predictor(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        preds = []
        for m in fold_models:
            m.eval()
            with torch.no_grad():
                preds.append(m(x).squeeze())
        preds = torch.stack(preds, dim=0)  # (K, *test_shape)
        # For each calibration sample, take the prediction of the fold model
        # that held that sample out.
        per_sample = preds[sample_to_fold]  # (n, *test_shape)
        res = residuals if per_sample.dim() == 1 else residuals.unsqueeze(-1)
        lo_sorted = torch.sort(per_sample - res, dim=0)[0]
        hi_sorted = torch.sort(per_sample + res, dim=0)[0]
        # CV+ order statistics (Barber et al., 2021).
        i_lo = math.ceil(alpha * (n + 1)) - 2
        i_hi = math.ceil((1 - alpha) * (n + 1)) - 1
        if i_lo < 0:
            lower = torch.full_like(lo_sorted[0], float("-inf"))
        else:
            lower = lo_sorted[i_lo]
        if i_hi >= n:
            upper = torch.full_like(hi_sorted[0], float("inf"))
        else:
            upper = hi_sorted[i_hi]
        return lower, upper

    return predictor
@torch.no_grad()
def conformalized_quantile_regression(
    quantile_model: torch.nn.Module,
    calib_loader: torch.utils.data.DataLoader,
    alpha: float,
) -> Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
    """
    Conformalized Quantile Regression (CQR)
    — Romano, Patterson, and Candes, *NeurIPS 2019*.

    Pairs a trained quantile-regression model with split conformal
    calibration to produce adaptive intervals with finite-sample coverage.

    Procedure:
        1. Score each calibration point by its worst interval violation,
           max(q_lo − y, y − q_hi).
        2. Take the finite-sample-corrected quantile of those scores.
        3. Widen the model's quantile predictions by that margin.

    Args:
        quantile_model: Model outputting (lower, upper) quantile predictions
            of shape (batch_size, 2): [:, 0] is the lower quantile and
            [:, 1] the upper quantile.
        calib_loader: DataLoader for calibration data.
        alpha: Miscoverage rate (e.g., 0.1 for 90% coverage).

    Returns:
        predictor: Function mapping inputs to (lower, upper) interval bounds.

    Reference:
        Romano, Patterson, and Candes. "Conformalized Quantile Regression."
        NeurIPS 2019. https://arxiv.org/abs/1905.03222

    Example:
        >>> quantile_net = QuantileRegressionNet()
        >>> # ... train quantile_net ...
        >>> predictor = conformalized_quantile_regression(
        ...     quantile_net, calib_loader, alpha=0.1
        ... )
        >>> lower, upper = predictor(test_x)
    """
    _validate_alpha(alpha)
    quantile_model.eval()

    score_chunks = []
    for xb, yb in calib_loader:
        out = quantile_model(xb)  # expected shape: (batch_size, 2)
        if out.dim() == 1:
            # Single-output model: treat the output as the median and
            # build symmetric intervals around it.
            out = out.unsqueeze(-1).repeat(1, 2)
        lo, hi = out[:, 0], out[:, 1]
        # Conformity score max(q_lo − y, y − q_hi): positive when y falls
        # outside the predicted interval, negative when inside.
        score_chunks.append(torch.max(lo - yb, yb - hi))
    qhat = _conformal_quantile(torch.cat(score_chunks), alpha)

    def predictor(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        quantile_model.eval()
        with torch.no_grad():
            out = quantile_model(x)
            if out.dim() == 1:
                out = out.unsqueeze(-1).repeat(1, 2)
            # Widen both quantile predictions by the calibrated margin.
            lo = out[:, 0] - qhat
            hi = out[:, 1] + qhat
        return lo.squeeze(), hi.squeeze()

    return predictor