# Source code for incerto.conformal.methods

"""
incerto.conformal.methods
-------------------------
Stateless helper functions that wrap a *trained* base model and produce
prediction sets or intervals at a user-specified mis-coverage rate α.

Each method adds only what is necessary for conformal inference—no
optimisers, schedulers, or training loops are defined here.
"""

from __future__ import annotations
from typing import Callable, Tuple, List
import torch
import numpy as np

# Type alias for an (inputs, labels) pair of tensors.
Batch = Tuple[torch.Tensor, torch.Tensor]  # (inputs, labels)


@torch.no_grad()
def inductive_conformal(
    model: torch.nn.Module,
    calib_loader: torch.utils.data.DataLoader,
    alpha: float,
) -> Callable[[torch.Tensor], List[torch.Tensor]]:
    """
    Classical Inductive (split) Conformal Prediction for classification —
    Vovk, Gammerman, and Shafer, *Algorithmic Learning in a Random World*
    (2005).

    The non-conformity score of a calibration example is
    ``1 - softmax(model(x))[true_label]``.  The threshold ``qhat`` is the
    finite-sample corrected quantile of those scores at level
    ``ceil((n + 1) * (1 - alpha)) / n`` (clipped to 1), which is what the
    marginal coverage guarantee requires.

    Args:
        model: Trained classifier returning logits of shape
            ``(batch, n_classes)``.  It is switched to eval mode.
        calib_loader: Iterable of ``(inputs, labels)`` calibration batches.
        alpha: Target mis-coverage rate in ``[0, 1]``.

    Returns:
        A predictor mapping a batch of inputs to a list with one 1-D tensor
        of class indices per sample (the prediction set; it may be empty).

    Raises:
        ValueError: If ``alpha`` is outside ``[0, 1]`` or no calibration
            samples were provided.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must lie in [0, 1], got {alpha}")

    model.eval()
    scores = []
    for x, y in calib_loader:
        conf = torch.softmax(model(x), dim=-1)
        rows = torch.arange(len(y), device=conf.device)
        # non-conformity score: 1 − probability assigned to the true class
        scores.append(1.0 - conf[rows, y])

    if not scores:
        raise ValueError("calib_loader yielded no calibration batches")
    all_scores = torch.cat(scores)
    n_calib = all_scores.numel()
    if n_calib == 0:
        raise ValueError("calibration set is empty")

    # Finite-sample correction; clipped so tiny calibration sets fall back
    # to the largest observed score instead of an invalid quantile level.
    q_level = min(float(np.ceil((n_calib + 1) * (1.0 - alpha))) / n_calib, 1.0)
    qhat = torch.quantile(all_scores, q_level)

    # The outer decorator only covers calibration; the closure needs its own
    # no_grad so inference does not build an autograd graph.
    @torch.no_grad()
    def predictor(x: torch.Tensor) -> List[torch.Tensor]:
        conf = torch.softmax(model(x), dim=-1)
        # keep class k  <=>  1 - p_k <= qhat  <=>  p_k >= 1 - qhat
        return [(conf_i >= 1.0 - qhat).nonzero().squeeze(-1) for conf_i in conf]

    return predictor
@torch.no_grad()
def mondrian_conformal(
    model: torch.nn.Module,
    calib_loader: torch.utils.data.DataLoader,
    alpha: float,
    partition_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
) -> Callable[[torch.Tensor], List[torch.Tensor]]:
    """
    Mondrian (category-wise) Conformal Prediction for classification — see
    Vovk, Gammerman, and Shafer, *Algorithmic Learning in a Random World*
    (2005).

    Calibration scores ``1 - softmax(model(x))[true_label]`` are grouped by
    the cell ``partition_fn(x, y)`` and one finite-sample corrected
    threshold is computed per cell.  At prediction time every candidate
    label ``k`` is judged against the threshold of the cell of ``(x, k)``;
    this is what yields coverage *conditional* on the cell.  (Choosing the
    cell from the arg-max prediction, as an earlier version did, applies the
    wrong threshold to all other classes.)

    Args:
        model: Trained classifier returning logits of shape
            ``(batch, n_classes)``.  It is switched to eval mode.
        calib_loader: Iterable of ``(inputs, labels)`` calibration batches.
        alpha: Target mis-coverage rate in ``[0, 1]``.
        partition_fn: Maps ``(inputs, labels)`` to one integer cell id per
            sample.  Defaults to the label itself (class-conditional CP).

    Returns:
        A predictor mapping a batch of inputs to a list with one 1-D tensor
        of class indices per sample.

    Raises:
        ValueError: If ``alpha`` is outside ``[0, 1]`` or no calibration
            samples were provided.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must lie in [0, 1], got {alpha}")

    if partition_fn is None:
        # default: partition by label
        def _by_label(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return y

        partition_fn = _by_label

    model.eval()
    cell_scores = {}
    for x, y in calib_loader:
        conf = torch.softmax(model(x), dim=-1)
        rows = torch.arange(len(y), device=conf.device)
        true_scores = 1.0 - conf[rows, y]
        for cell, score in zip(partition_fn(x, y), true_scores):
            cell_scores.setdefault(int(cell), []).append(score)

    if not cell_scores:
        raise ValueError("calib_loader yielded no calibration samples")

    qhats = {}
    for cell, values in cell_scores.items():
        n_cell = len(values)
        # per-cell finite-sample correction, clipped to a valid level
        q_level = min(float(np.ceil((n_cell + 1) * (1.0 - alpha))) / n_cell, 1.0)
        qhats[cell] = float(torch.quantile(torch.stack(values), q_level))

    # The outer decorator only covers calibration; guard inference as well.
    @torch.no_grad()
    def predictor(x: torch.Tensor) -> List[torch.Tensor]:
        conf = torch.softmax(model(x), dim=-1)
        batch, n_classes = conf.shape
        thresholds = torch.zeros_like(conf)
        for k in range(n_classes):
            candidate = torch.full(
                (batch,), k, dtype=torch.long, device=conf.device
            )
            cells = partition_fn(x, candidate)
            # A cell never seen in calibration has no usable threshold; use
            # qhat = 1 (the largest possible score) so the class is kept,
            # which is the conservative choice.
            thresholds[:, k] = torch.tensor(
                [1.0 - qhats.get(int(c), 1.0) for c in cells],
                dtype=conf.dtype,
                device=conf.device,
            )
        return [row.nonzero().squeeze(-1) for row in conf >= thresholds]

    return predictor
@torch.no_grad()
def aps(
    model: torch.nn.Module,
    calib_loader: torch.utils.data.DataLoader,
    alpha: float,
) -> Callable[[torch.Tensor], List[torch.Tensor]]:
    """
    Adaptive Prediction Sets (APS) — Romano, Sesia, and Candès,
    *NeurIPS 2020* (deterministic variant, no randomisation).

    The score of a calibration example is the cumulative softmax mass of all
    classes ranked at or above the true label.  The threshold ``qhat`` is
    the finite-sample corrected quantile of these scores at level
    ``ceil((n + 1) * (1 - alpha)) / n`` (clipped to 1).  A class enters the
    prediction set when its own cumulative mass is at most ``qhat``.

    Args:
        model: Trained classifier returning logits of shape
            ``(batch, n_classes)``.  It is switched to eval mode.
        calib_loader: Iterable of ``(inputs, labels)`` calibration batches.
        alpha: Target mis-coverage rate in ``[0, 1]``.

    Returns:
        A predictor mapping a batch of inputs to a list with one 1-D tensor
        of class indices per sample, ordered by decreasing probability.  A
        set is empty when even the top class exceeds ``qhat``.

    Raises:
        ValueError: If ``alpha`` is outside ``[0, 1]`` or no calibration
            samples were provided.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must lie in [0, 1], got {alpha}")

    model.eval()
    scores = []
    for x, y in calib_loader:
        probs = torch.softmax(model(x), dim=-1)
        probs_sorted, idx_sorted = torch.sort(probs, descending=True, dim=-1)
        cumprobs = probs_sorted.cumsum(dim=-1)
        # position of the true label in the descending ranking
        ranks = (idx_sorted == y.unsqueeze(-1)).nonzero(as_tuple=True)[1]
        rows = torch.arange(len(y), device=cumprobs.device)
        # cumulative probability up to and including the true label
        scores.append(cumprobs[rows, ranks])

    if not scores:
        raise ValueError("calib_loader yielded no calibration batches")
    all_scores = torch.cat(scores)
    n_calib = all_scores.numel()
    if n_calib == 0:
        raise ValueError("calibration set is empty")

    # Finite-sample correction, clipped to a valid quantile level.
    q_level = min(float(np.ceil((n_calib + 1) * (1.0 - alpha))) / n_calib, 1.0)
    qhat = torch.quantile(all_scores, q_level)

    # The outer decorator only covers calibration; guard inference as well.
    @torch.no_grad()
    def predictor(x: torch.Tensor) -> List[torch.Tensor]:
        probs, idx = torch.sort(
            torch.softmax(model(x), dim=-1), descending=True, dim=-1
        )
        cumprobs = probs.cumsum(dim=-1)
        # Boolean-mask indexing already returns a fresh tensor, so no
        # clone/detach is needed.
        return [
            idx_i[cumprobs_i <= qhat] for idx_i, cumprobs_i in zip(idx, cumprobs)
        ]

    return predictor
@torch.no_grad()
def raps(
    model: torch.nn.Module,
    calib_loader: torch.utils.data.DataLoader,
    alpha: float,
    lam: float = 0.0,
    k_reg: int = 1,
) -> Callable[[torch.Tensor], List[torch.Tensor]]:
    """
    Regularized Adaptive Prediction Sets (RAPS) — Angelopoulos, Bates,
    Malik, and Jordan, *ICLR 2021*.

    As implemented here, the score of the class at (1-based) rank ``r`` is
    the cumulative softmax mass through rank ``r`` plus the penalty
    ``lam * r``.  ``k_reg`` acts as a minimum set size: the ``k_reg`` most
    probable classes are always included.

    The threshold ``qhat`` is the quantile of the calibration scores at
    level ``ceil((n + 1) * (1 - alpha)) / n`` clipped to 1, where ``n`` is
    the number of calibration scores actually collected.  (The unclipped
    level ``(1 - alpha) * (1 + 1 / n)`` exceeds 1 for small ``alpha`` and
    makes ``torch.quantile`` fail.)

    Args:
        model: Trained classifier returning logits of shape
            ``(batch, n_classes)``.  It is switched to eval mode.
        calib_loader: Iterable of ``(inputs, labels)`` calibration batches.
        alpha: Target mis-coverage rate in ``[0, 1]``.
        lam: Rank penalty weight (λ); ``0`` gives unregularised scores.
        k_reg: Minimum number of classes in every prediction set.

    Returns:
        A predictor mapping a batch of inputs to a list with one 1-D tensor
        of class indices per sample, ordered by decreasing probability.

    Raises:
        ValueError: If ``alpha`` is outside ``[0, 1]`` or no calibration
            samples were provided.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must lie in [0, 1], got {alpha}")

    model.eval()
    scores = []
    for x, y in calib_loader:
        probs, idx = torch.sort(torch.softmax(model(x), dim=-1), descending=True)
        rank = (idx == y[:, None]).nonzero()[:, 1]
        penalty = lam * torch.arange(
            1, probs.size(-1) + 1, device=probs.device
        )
        g = probs.cumsum(dim=-1) + penalty
        rows = torch.arange(len(y), device=probs.device)
        scores.append(g[rows, rank])

    if not scores:
        raise ValueError("calib_loader yielded no calibration batches")
    all_scores = torch.cat(scores)
    # Count the scores themselves rather than len(calib_loader.dataset):
    # samplers / drop_last can make the two differ, and plain iterables of
    # batches have no .dataset at all.
    n_calib = all_scores.numel()
    if n_calib == 0:
        raise ValueError("calibration set is empty")
    q_level = min(float(np.ceil((n_calib + 1) * (1.0 - alpha))) / n_calib, 1.0)
    qhat = torch.quantile(all_scores, q_level)

    # The outer decorator only covers calibration; guard inference as well.
    @torch.no_grad()
    def predictor(x: torch.Tensor) -> List[torch.Tensor]:
        probs, idx = torch.sort(torch.softmax(model(x), dim=-1), descending=True)
        positions = torch.arange(probs.size(-1), device=probs.device)
        g = probs.cumsum(dim=-1) + lam * (positions + 1)
        # Union with the first k_reg ranks guarantees the minimum size.
        # Marking only the first (k_reg - |S|) ranks, as before, re-selects
        # classes already in S and can leave the set smaller than k_reg.
        keep = (g <= qhat) | (positions[None] < k_reg)
        return [idx_i[keep_i] for idx_i, keep_i in zip(idx, keep)]

    return predictor
# -------- Regression flavours (absolute residual conformity) -------- #
def jackknife_plus(
    model_fn: Callable[[torch.utils.data.Dataset], torch.nn.Module],
    train_dataset: torch.utils.data.Dataset,
    alpha: float,
) -> Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
    """
    Jackknife+ prediction intervals — Barber, Candès, Ramdas, and
    Tibshirani, "Predictive inference with the jackknife+",
    *Ann. Statist.* 2021.

    For every training point ``i`` a model is fitted without it, giving the
    leave-one-out residual ``R_i = |y_i - mu_{-i}(x_i)|``.  For a new ``x``
    the interval is::

        lower = k-th largest  of { mu_{-i}(x) - R_i }
        upper = k-th smallest of { mu_{-i}(x) + R_i }

    with ``k = ceil((1 - alpha) * (n + 1))``.  When ``k`` exceeds ``n`` the
    theoretical bound is infinite; ``k`` is clipped to ``n`` here so the
    most extreme candidate is returned instead.  The paper's worst-case
    coverage guarantee for this construction is ``1 - 2 * alpha``.

    Notes:
        * Training happens inside ``model_fn``, so this function is *not*
          wrapped in ``torch.no_grad``; only the forward passes are.
        * All ``n`` leave-one-out models are kept for prediction, so
          ``model_fn`` must return an independent model on each call.
          ``model_fn`` is never called on the full dataset.

    Args:
        model_fn: Fits and returns a model on the supplied dataset split.
        train_dataset: Indexable dataset of ``(x, y)`` pairs with scalar
            targets.
        alpha: Mis-coverage level in ``[0, 1]``.

    Returns:
        A predictor mapping inputs to ``(lower, upper)`` bounds, shaped like
        ``model(x).squeeze()``.

    Raises:
        ValueError: If ``alpha`` is outside ``[0, 1]`` or the dataset has
            fewer than two points.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must lie in [0, 1], got {alpha}")
    n = len(train_dataset)
    if n < 2:
        raise ValueError("jackknife+ needs at least two training points")

    loo_models = []
    residuals = torch.empty((n,))  # R_i = |y_i − ŷ_i^(-i)|
    for i in range(n):
        others = [j for j in range(n) if j != i]
        model = model_fn(torch.utils.data.Subset(train_dataset, others))
        if isinstance(model, torch.nn.Module):
            model.eval()
        xi, yi = train_dataset[i]
        with torch.no_grad():
            pred_i = model(xi.unsqueeze(0)).squeeze().cpu()
        target_i = torch.as_tensor(yi, dtype=residuals.dtype).cpu().squeeze()
        residuals[i] = torch.abs(pred_i - target_i)
        loo_models.append(model)

    # order-statistic rank, clipped into [1, n]
    k = min(max(int(np.ceil((1.0 - alpha) * (n + 1))), 1), n)

    @torch.no_grad()
    def predictor(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # (n, ...) stack of leave-one-out predictions for the query points
        loo_preds = torch.stack([m(x).squeeze() for m in loo_models])
        r = residuals.to(loo_preds).reshape((n,) + (1,) * (loo_preds.dim() - 1))
        lower_sorted = torch.sort(loo_preds - r, dim=0).values
        upper_sorted = torch.sort(loo_preds + r, dim=0).values
        # k-th largest == ascending index n - k; k-th smallest == index k - 1
        return lower_sorted[n - k], upper_sorted[k - 1]

    return predictor
def cv_plus(
    model_fn: Callable[[torch.utils.data.Dataset], torch.nn.Module],
    train_dataset: torch.utils.data.Dataset,
    folds: int,
    alpha: float,
) -> Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
    """
    Cross-Validation+ (CV+) prediction intervals — Barber, Candès, Ramdas,
    and Tibshirani, "Predictive inference with the jackknife+",
    *Ann. Statist.* 2021.

    The data are randomly split into ``folds`` parts.  One model is fitted
    per fold on the remaining parts, and each held-out point ``i`` gets the
    residual ``R_i = |y_i - mu_{-fold(i)}(x_i)|``.  For a new ``x`` the
    interval is::

        lower = k-th largest  of { mu_{-fold(i)}(x) - R_i }
        upper = k-th smallest of { mu_{-fold(i)}(x) + R_i }

    with ``k = ceil((1 - alpha) * (n + 1))``, clipped to ``n`` (the
    theoretical bound is infinite beyond that).  This needs only ``folds``
    model fits instead of the ``n`` fits of jackknife+.

    Notes:
        * Training happens inside ``model_fn``, so this function is *not*
          wrapped in ``torch.no_grad``; only the forward passes are.
        * The fold models are kept for prediction, so ``model_fn`` must
          return an independent model on each call.  ``model_fn`` is never
          called on the full dataset.
        * The split uses ``torch.randperm``; seed torch for reproducibility.

    Args:
        model_fn: Fits and returns a model on the supplied dataset split.
        train_dataset: Indexable dataset of ``(x, y)`` pairs with scalar
            targets.
        folds: Number of folds, between 2 and ``len(train_dataset)``.
        alpha: Mis-coverage level in ``[0, 1]``.

    Returns:
        A predictor mapping inputs to ``(lower, upper)`` bounds, shaped like
        ``model(x).squeeze()``.

    Raises:
        ValueError: If ``alpha`` or ``folds`` is out of range.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must lie in [0, 1], got {alpha}")
    n = len(train_dataset)
    if not 2 <= folds <= n:
        raise ValueError(f"folds must lie in [2, {n}], got {folds}")

    # Plain Python ints: cheap slicing and valid indices for any dataset.
    perm = torch.randperm(n).tolist()
    fold_sizes = [(n + i) // folds for i in range(folds)]  # sums to n

    fold_models = []
    residual_chunks = []
    fold_id_chunks = []
    start = 0
    for k_fold, size in enumerate(fold_sizes):
        val_idx = perm[start : start + size]
        train_idx = perm[:start] + perm[start + size :]
        start += size

        model = model_fn(torch.utils.data.Subset(train_dataset, train_idx))
        if isinstance(model, torch.nn.Module):
            model.eval()

        held_out = [train_dataset[i] for i in val_idx]
        xk = torch.stack([sample[0] for sample in held_out])
        with torch.no_grad():
            preds = model(xk).reshape(-1)
        yk = torch.stack(
            [torch.as_tensor(sample[1]).reshape(()).to(preds) for sample in held_out]
        )
        residual_chunks.append(torch.abs(yk - preds))
        fold_id_chunks.append(torch.full((size,), k_fold, dtype=torch.long))
        fold_models.append(model)

    residuals = torch.cat(residual_chunks)  # R_i, in held-out order
    fold_ids = torch.cat(fold_id_chunks)  # fold whose model excluded point i

    # order-statistic rank, clipped into [1, n]
    k = min(max(int(np.ceil((1.0 - alpha) * (n + 1))), 1), n)

    @torch.no_grad()
    def predictor(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # (folds, ...) predictions, expanded to one row per training point
        per_fold = torch.stack([m(x).squeeze() for m in fold_models])
        centers = per_fold[fold_ids.to(per_fold.device)]
        r = residuals.to(centers).reshape((n,) + (1,) * (centers.dim() - 1))
        lower_sorted = torch.sort(centers - r, dim=0).values
        upper_sorted = torch.sort(centers + r, dim=0).values
        # k-th largest == ascending index n - k; k-th smallest == index k - 1
        return lower_sorted[n - k], upper_sorted[k - 1]

    return predictor
@torch.no_grad()
def conformalized_quantile_regression(
    quantile_model: torch.nn.Module,
    calib_loader: torch.utils.data.DataLoader,
    alpha: float,
    q_low: float | None = None,
    q_high: float | None = None,
) -> Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
    """
    Conformalized Quantile Regression (CQR) — Romano, Patterson, and Candès,
    *NeurIPS 2019*.

    Combines quantile regression with split conformal prediction to produce
    adaptive prediction intervals with finite-sample coverage guarantees.

    The method works by:
    1. Taking a quantile regression model that predicts lower and upper
       quantiles.
    2. Computing conformity scores on a calibration set as the signed
       distance by which the label falls outside the predicted interval,
       ``max(q_lo - y, y - q_hi)`` (negative when the label is inside).
    3. Widening (or shrinking) the predicted interval by the
       ``ceil((n + 1) * (1 - alpha)) / n`` quantile of those scores.

    Args:
        quantile_model: A model that outputs (lower_quantile, upper_quantile)
            predictions. Expected to output shape (batch_size, 2) where:
            - [:, 0] is the lower quantile prediction
            - [:, 1] is the upper quantile prediction
            A 1-D output is treated as a point prediction used for both
            bounds.
        calib_loader: Iterable of (inputs, targets) calibration batches.
            Targets may be shaped (batch_size,) or (batch_size, 1).
        alpha: Miscoverage rate (e.g., 0.1 for 90% coverage)
        q_low: Nominal lower quantile level of the model. Accepted for API
            compatibility; it does not affect calibration.
        q_high: Nominal upper quantile level of the model. Accepted for API
            compatibility; it does not affect calibration.

    Returns:
        predictor: Function mapping inputs to (lower, upper) prediction
        intervals

    Raises:
        ValueError: If no calibration samples were provided.

    Reference:
        Romano, Patterson, and Candes. "Conformalized Quantile Regression."
        NeurIPS 2019. https://arxiv.org/abs/1905.03222

    Example:
        >>> # Train quantile model (outputs lower and upper quantiles)
        >>> quantile_net = QuantileRegressionNet()
        >>> # ... train quantile_net ...
        >>> predictor = conformalized_quantile_regression(
        ...     quantile_net, calib_loader, alpha=0.1
        ... )
        >>> lower, upper = predictor(test_x)
    """
    quantile_model.eval()

    scores = []
    for x, y in calib_loader:
        # Get quantile predictions (batch_size, 2)
        preds = quantile_model(x)
        if preds.dim() == 1:
            # Single output per sample: use it as both bounds
            preds = preds.unsqueeze(-1).repeat(1, 2)

        q_lo = preds[:, 0]  # lower quantile prediction
        q_hi = preds[:, 1]  # upper quantile prediction

        # Flatten targets so a (batch, 1) label tensor cannot broadcast
        # against the (batch,) predictions into a (batch, batch) matrix.
        y = y.reshape(-1).to(device=preds.device, dtype=preds.dtype)

        # Conformity score: max(q_lo - y, y - q_hi)
        scores.append(torch.max(q_lo - y, y - q_hi))

    if not scores:
        raise ValueError("calib_loader yielded no calibration batches")

    # Compute calibrated quantile with finite-sample correction
    all_scores = torch.cat(scores)
    n_calib = len(all_scores)
    if n_calib == 0:
        raise ValueError("calibration set is empty")
    q_level = float(np.ceil((n_calib + 1) * (1 - alpha))) / n_calib
    q_level = min(q_level, 1.0)  # ensure valid quantile
    qhat = torch.quantile(all_scores, q_level)

    # The outer decorator only covers calibration; guard inference as well.
    @torch.no_grad()
    def predictor(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict conformalized intervals for new inputs.

        Args:
            x: Input tensor of shape (batch_size, ...)

        Returns:
            lower: Lower bound of prediction interval (batch_size,)
            upper: Upper bound of prediction interval (batch_size,)
        """
        quantile_model.eval()
        preds = quantile_model(x)

        if preds.dim() == 1:
            preds = preds.unsqueeze(-1).repeat(1, 2)

        # Adjust quantile predictions by calibrated correction
        lower = preds[:, 0] - qhat
        upper = preds[:, 1] + qhat

        return lower.squeeze(), upper.squeeze()

    return predictor