Source code for incerto.calibration.methods

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sklearn.isotonic import IsotonicRegression
from sklearn.linear_model import LogisticRegression
from torch.distributions import Categorical

from .base import BaseCalibrator


class IdentityCalibrator(BaseCalibrator):
    """No-op calibrator that returns the original softmax probabilities."""

    def fit(self, logits: torch.Tensor, labels: torch.Tensor):  # noqa: ARG002
        # No parameters to fit
        return self

    def predict(self, logits: torch.Tensor) -> Categorical:
        probs = F.softmax(logits, dim=1)
        return Categorical(probs=probs)

    def state_dict(self) -> dict:
        """Return empty state dict (no parameters to save)."""
        return {}

    def load_state_dict(self, state: dict) -> None:
        """Load state (no-op for identity calibrator)."""

    def __repr__(self) -> str:
        return "IdentityCalibrator()"


class TemperatureScaling(nn.Module, BaseCalibrator):
    """Temperature scaling for calibration: scales logits by a learned temperature."""

    def __init__(self, init_temp: float = 1.0):
        super().__init__()
        # Temperature parameter, kept > 0 by clamping wherever it is used.
        self.temperature = nn.Parameter(torch.tensor(init_temp))

    def fit(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        lr: float = 0.01,
        max_iters: int = 50,
    ):
        """
        Fit temperature on validation logits and labels by minimizing NLL.

        Args:
            logits: Tensor (n_samples, n_classes).
            labels: Tensor (n_samples,) with class indices.
            lr: Learning rate for L-BFGS optimizer.
            max_iters: Maximum iterations for optimizer.
        """
        # Move to the same device as the logits.
        device = logits.device
        self.to(device)
        labels = labels.to(device)

        optimizer = torch.optim.LBFGS([self.temperature], lr=lr, max_iter=max_iters)
        nll = nn.CrossEntropyLoss()

        def _eval():
            optimizer.zero_grad()
            scaled = logits / self.temperature.clamp(min=1e-6)
            loss = nll(scaled, labels)
            loss.backward()
            return loss

        optimizer.step(_eval)
        return self

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        # Used for direct scaling.
        return logits / self.temperature.clamp(min=1e-6)

    def predict(self, logits: torch.Tensor) -> Categorical:
        scaled = self.forward(logits)
        probs = F.softmax(scaled, dim=1)
        return Categorical(probs=probs)

    def __repr__(self) -> str:
        return f"TemperatureScaling(temperature={self.temperature.item():.4f})"
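

# Hedged usage sketch (not part of the original module): the typical fit/predict
# flow on held-out validation logits. The synthetic tensors and the helper name
# below are illustrative assumptions only.
def _demo_temperature_scaling() -> Categorical:
    torch.manual_seed(0)
    val_logits = 3.0 * torch.randn(256, 10)   # deliberately sharp (over-confident) logits
    val_labels = torch.randint(0, 10, (256,))
    ts = TemperatureScaling(init_temp=1.0).fit(val_logits, val_labels)
    # Calibrated predictive distribution for new logits.
    return ts.predict(torch.randn(32, 10))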


class IsotonicRegressionCalibrator(BaseCalibrator):
    """Multi-class isotonic regression calibration (per-class fitting)."""

    def __init__(self, out_of_bounds: str = "clip"):
        # out_of_bounds: 'clip' or 'nan' (passed through to sklearn's IsotonicRegression)
        self.out_of_bounds = out_of_bounds
        self.calibrators = []
        self.n_classes = 0

    def fit(self, logits: torch.Tensor, labels: torch.Tensor):
        probs = F.softmax(logits, dim=1).cpu().detach().numpy()
        labels_np = labels.cpu().detach().numpy()
        _, n_classes = probs.shape
        self.n_classes = n_classes
        self.calibrators = []
        for k in range(n_classes):
            # One-vs-rest fit: class-k probability against the binary indicator.
            ir = IsotonicRegression(out_of_bounds=self.out_of_bounds)
            ir.fit(probs[:, k], (labels_np == k).astype(int))
            self.calibrators.append(ir)
        return self

    def predict(self, logits: torch.Tensor) -> Categorical:
        probs = F.softmax(logits, dim=1).cpu().detach().numpy()
        calibrated = np.zeros_like(probs)
        for k, ir in enumerate(self.calibrators):
            calibrated[:, k] = ir.predict(probs[:, k])
        calibrated = torch.tensor(calibrated, device=logits.device, dtype=torch.float32)
        # Re-normalize; clamp the row sums to guard against all-zero rows.
        calibrated = calibrated / calibrated.sum(dim=1, keepdim=True).clamp(min=1e-12)
        return Categorical(probs=calibrated)

    def state_dict(self) -> dict:
        """Save isotonic regression calibrators."""
        import pickle

        return {
            "n_classes": self.n_classes,
            "out_of_bounds": self.out_of_bounds,
            "calibrators": pickle.dumps(self.calibrators),
        }

    def load_state_dict(self, state: dict) -> None:
        """Load isotonic regression calibrators."""
        import pickle

        from ..exceptions import SerializationError

        try:
            self.n_classes = state["n_classes"]
            self.out_of_bounds = state["out_of_bounds"]
            self.calibrators = pickle.loads(state["calibrators"])
        except Exception as e:
            raise SerializationError(f"Failed to load state: {e}") from e

    def __repr__(self) -> str:
        return (
            f"IsotonicRegressionCalibrator(n_classes={self.n_classes}, "
            f"out_of_bounds='{self.out_of_bounds}')"
        )


class HistogramBinningCalibrator(BaseCalibrator):
    """
    Histogram binning calibration: bins predicted probabilities and uses
    empirical frequencies.
    """

    def __init__(self, n_bins: int = 10):
        self.n_bins = n_bins
        self.bin_edges: list = []
        self.bin_true_rates: list = []

    def fit(self, logits: torch.Tensor, labels: torch.Tensor):
        probs = F.softmax(logits, dim=1).cpu().detach().numpy()
        labels_np = labels.cpu().detach().numpy()
        _, n_classes = probs.shape
        self.bin_edges = []
        self.bin_true_rates = []
        for k in range(n_classes):
            pk = probs[:, k]
            edges = np.linspace(0.0, 1.0, self.n_bins + 1)
            # Clip so that p == 0.0 lands in the first bin instead of index -1.
            bin_ids = np.digitize(pk, edges, right=True) - 1
            bin_ids = np.clip(bin_ids, 0, self.n_bins - 1)
            true_rates = np.zeros(self.n_bins)
            for b in range(self.n_bins):
                idx = bin_ids == b
                if idx.sum() > 0:
                    true_rates[b] = (labels_np[idx] == k).sum() / idx.sum()
                else:
                    true_rates[b] = 0.0
            self.bin_edges.append(edges)
            self.bin_true_rates.append(true_rates)
        return self

    def predict(self, logits: torch.Tensor) -> Categorical:
        probs = F.softmax(logits, dim=1).cpu().detach().numpy()
        n_samples, n_classes = probs.shape
        calibrated = np.zeros_like(probs)
        for k in range(n_classes):
            edges = self.bin_edges[k]
            rates = self.bin_true_rates[k]
            pk = probs[:, k]
            bin_ids = np.digitize(pk, edges, right=True) - 1
            bin_ids = np.clip(bin_ids, 0, len(rates) - 1)
            calibrated[:, k] = rates[bin_ids]
        calibrated = torch.tensor(calibrated, device=logits.device, dtype=torch.float32)
        calibrated = calibrated / calibrated.sum(dim=1, keepdim=True).clamp(min=1e-12)
        return Categorical(probs=calibrated)

    def state_dict(self) -> dict:
        """Save histogram binning state."""
        return {
            "n_bins": self.n_bins,
            "bin_edges": self.bin_edges,
            "bin_true_rates": self.bin_true_rates,
        }

    def load_state_dict(self, state: dict) -> None:
        """Load histogram binning state."""
        from ..exceptions import SerializationError

        try:
            self.n_bins = state["n_bins"]
            self.bin_edges = state["bin_edges"]
            self.bin_true_rates = state["bin_true_rates"]
        except Exception as e:
            raise SerializationError(f"Failed to load state: {e}") from e

    def __repr__(self) -> str:
        n_classes = len(self.bin_edges) if self.bin_edges else 0
        return (
            f"HistogramBinningCalibrator(n_bins={self.n_bins}, n_classes={n_classes})"
        )


class PlattScalingCalibrator(BaseCalibrator):
    """Platt scaling (logistic regression) calibration per class (one-vs-rest)."""

    def __init__(self):
        self.models: list = []
        self.n_classes: int = 0

    def fit(self, logits: torch.Tensor, labels: torch.Tensor):
        probs = F.softmax(logits, dim=1).cpu().detach().numpy()
        labels_np = labels.cpu().detach().numpy()
        _, n_classes = probs.shape
        self.n_classes = n_classes
        self.models = []
        for k in range(n_classes):
            lr = LogisticRegression()
            lr.fit(probs[:, [k]], (labels_np == k).astype(int))
            self.models.append(lr)
        return self

    def predict(self, logits: torch.Tensor) -> Categorical:
        probs = F.softmax(logits, dim=1).cpu().detach().numpy()
        calibrated = np.zeros_like(probs)
        for k, lr in enumerate(self.models):
            calibrated[:, k] = lr.predict_proba(probs[:, [k]])[:, 1]
        calibrated = torch.tensor(calibrated, device=logits.device, dtype=torch.float32)
        calibrated = calibrated / calibrated.sum(dim=1, keepdim=True).clamp(min=1e-12)
        return Categorical(probs=calibrated)

    def state_dict(self) -> dict:
        """Save Platt scaling models."""
        import pickle

        return {
            "n_classes": self.n_classes,
            "models": pickle.dumps(self.models),
        }

    def load_state_dict(self, state: dict) -> None:
        """Load Platt scaling models."""
        import pickle

        from ..exceptions import SerializationError

        try:
            self.n_classes = state["n_classes"]
            self.models = pickle.loads(state["models"])
        except Exception as e:
            raise SerializationError(f"Failed to load state: {e}") from e

    def __repr__(self) -> str:
        return f"PlattScalingCalibrator(n_classes={self.n_classes})"


class VectorScaling(nn.Module, BaseCalibrator):
    """
    Vector Scaling (Guo et al., 2017).

    Extends temperature scaling by learning a separate temperature for each
    class: z_scaled = z / T, where T is a per-class vector (a bias-free
    variant of the vector scaling described by Guo et al.).
    """

    def __init__(self, n_classes: int):
        super().__init__()
        self.temperature = nn.Parameter(torch.ones(n_classes))

    def fit(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        lr: float = 0.01,
        max_iters: int = 50,
    ):
        """
        Fit the vector of temperatures on validation logits and labels.

        Args:
            logits: Tensor (n_samples, n_classes).
            labels: Tensor (n_samples,) with class indices.
            lr: Learning rate for L-BFGS optimizer.
            max_iters: Maximum iterations for optimizer.
        """
        device = logits.device
        self.to(device)
        labels = labels.to(device)

        optimizer = torch.optim.LBFGS([self.temperature], lr=lr, max_iter=max_iters)
        nll = nn.CrossEntropyLoss()

        def _eval():
            optimizer.zero_grad()
            scaled = logits / self.temperature.clamp(min=1e-6)
            loss = nll(scaled, labels)
            loss.backward()
            return loss

        optimizer.step(_eval)
        return self

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        return logits / self.temperature.clamp(min=1e-6)

    def predict(self, logits: torch.Tensor) -> Categorical:
        scaled = self.forward(logits)
        probs = F.softmax(scaled, dim=1)
        return Categorical(probs=probs)

    def __repr__(self) -> str:
        temps = self.temperature.detach().cpu().numpy()
        return (
            f"VectorScaling(n_classes={len(temps)}, "
            f"temperature_range=[{temps.min():.4f}, {temps.max():.4f}])"
        )


class MatrixScaling(nn.Module, BaseCalibrator):
    """
    Matrix Scaling (Guo et al., 2017).

    Most general affine transformation: z_scaled = W @ z + b, where W is a
    learned matrix and b is a learned bias vector.
    """

    def __init__(self, n_classes: int):
        super().__init__()
        self.weight = nn.Parameter(torch.eye(n_classes))
        self.bias = nn.Parameter(torch.zeros(n_classes))

    def fit(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        lr: float = 0.01,
        max_iters: int = 50,
    ):
        """
        Fit the transformation matrix and bias on validation logits.

        Args:
            logits: Tensor (n_samples, n_classes).
            labels: Tensor (n_samples,) with class indices.
            lr: Learning rate for L-BFGS optimizer.
            max_iters: Maximum iterations for optimizer.
        """
        device = logits.device
        self.to(device)
        labels = labels.to(device)

        optimizer = torch.optim.LBFGS(
            [self.weight, self.bias], lr=lr, max_iter=max_iters
        )
        nll = nn.CrossEntropyLoss()

        def _eval():
            optimizer.zero_grad()
            scaled = logits @ self.weight.T + self.bias
            loss = nll(scaled, labels)
            loss.backward()
            return loss

        optimizer.step(_eval)
        return self

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        return logits @ self.weight.T + self.bias

    def predict(self, logits: torch.Tensor) -> Categorical:
        scaled = self.forward(logits)
        probs = F.softmax(scaled, dim=1)
        return Categorical(probs=probs)

    def __repr__(self) -> str:
        n_classes = self.weight.shape[0]
        return f"MatrixScaling(n_classes={n_classes})"
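

# Hedged usage sketch (not part of the original module): unlike TemperatureScaling,
# the class-wise scalers need the number of classes up front. The synthetic tensors
# and the helper name are illustrative assumptions only.
def _demo_classwise_scaling(n_classes: int = 10) -> tuple:
    torch.manual_seed(0)
    val_logits = torch.randn(512, n_classes)
    val_labels = torch.randint(0, n_classes, (512,))
    vs = VectorScaling(n_classes).fit(val_logits, val_labels)
    ms = MatrixScaling(n_classes).fit(val_logits, val_labels)
    test_logits = torch.randn(64, n_classes)
    return vs.predict(test_logits), ms.predict(test_logits)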


class DirichletCalibrator(nn.Module, BaseCalibrator):
    """
    Dirichlet Calibration (Kull et al., 2019).

    Maps logits to Dirichlet distribution parameters using a linear
    transformation, providing a more flexible calibration than temperature
    scaling.

    Reference:
        Kull et al., "Beyond temperature scaling: Obtaining well-calibrated
        multi-class probabilities with Dirichlet calibration" (NeurIPS 2019)

    Args:
        n_classes: Number of classes
        mu: Regularization parameter (default: None for unregularized)
    """

    def __init__(self, n_classes: int, mu: float = None):
        super().__init__()
        self.n_classes = n_classes
        self.mu = mu
        # Learnable parameters of the linear map.
        self.weight = nn.Parameter(torch.eye(n_classes))
        self.bias = nn.Parameter(torch.zeros(n_classes))

    def fit(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        lr: float = 0.01,
        max_iters: int = 100,
    ):
        """
        Fit Dirichlet parameters on validation data.

        Args:
            logits: Validation logits (N, C)
            labels: Validation labels (N,)
            lr: Learning rate
            max_iters: Maximum optimization iterations
        """
        device = logits.device
        self.to(device)
        labels = labels.to(device)

        optimizer = torch.optim.LBFGS(
            [self.weight, self.bias], lr=lr, max_iter=max_iters
        )

        def _eval():
            optimizer.zero_grad()
            # Transform logits and minimize the NLL of the transformed logits
            # (a simplified surrogate for the Dirichlet log-likelihood).
            transformed = logits @ self.weight.T + self.bias
            loss = F.cross_entropy(transformed, labels)
            # Add regularization towards the identity map if specified.
            if self.mu is not None:
                reg = self.mu * (
                    torch.norm(
                        self.weight - torch.eye(self.n_classes, device=device)
                    )
                    ** 2
                    + torch.norm(self.bias) ** 2
                )
                loss = loss + reg
            loss.backward()
            return loss

        optimizer.step(_eval)
        return self

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        """Apply the Dirichlet calibration transformation."""
        return logits @ self.weight.T + self.bias

    def predict(self, logits: torch.Tensor) -> Categorical:
        """Get calibrated predictions."""
        transformed = self.forward(logits)
        probs = F.softmax(transformed, dim=1)
        return Categorical(probs=probs)

    def __repr__(self) -> str:
        mu_str = f"{self.mu:.4f}" if self.mu is not None else "None"
        return f"DirichletCalibrator(n_classes={self.n_classes}, mu={mu_str})"
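

# Hedged usage sketch (not part of the original module): fitting the Dirichlet
# calibrator with the identity-anchored regularizer enabled. The synthetic
# tensors, the mu value, and the helper name are illustrative assumptions only.
def _demo_dirichlet_calibration() -> Categorical:
    torch.manual_seed(0)
    val_logits = torch.randn(512, 10)
    val_labels = torch.randint(0, 10, (512,))
    dc = DirichletCalibrator(n_classes=10, mu=1e-3).fit(val_logits, val_labels)
    return dc.predict(torch.randn(64, 10))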


class BetaCalibrator(BaseCalibrator):
    """
    Beta Calibration for binary classification (Kull et al., 2017).

    Maps uncalibrated probabilities to calibrated probabilities with a more
    flexible mapping than Platt scaling. This implementation approximates the
    beta map with isotonic regression and falls back to
    IsotonicRegressionCalibrator for multiclass inputs.

    Reference:
        Kull et al., "Beta calibration: a well-founded and easily implemented
        improvement on logistic calibration" (AISTATS 2017)

    Args:
        method: Fitting method ('mle' or 'map')
    """

    def __init__(self, method: str = "mle"):
        self.method = method
        self.a = None  # Beta parameter alpha
        self.b = None  # Beta parameter beta
        self.map_params = None  # Mapping parameters
        self.is_binary = True
        self.calibrator = None

    def fit(self, logits: torch.Tensor, labels: torch.Tensor):
        """
        Fit Beta calibration on binary classification data.
        Falls back to isotonic regression for multiclass.

        Args:
            logits: Validation logits (N, 2), (N,) single-logit binary, or (N, C) multiclass
            labels: Binary labels (N,)
        """
        # Check whether the problem is binary or multiclass.
        if logits.dim() == 2 and logits.shape[1] > 2:
            # Multiclass: fall back to isotonic regression.
            self.is_binary = False
            self.calibrator = IsotonicRegressionCalibrator()
            self.calibrator.fit(logits, labels)
            return self

        self.is_binary = True

        # Convert to positive-class probabilities.
        if logits.dim() == 2:
            # Two-column logits for the binary case.
            probs = F.softmax(logits, dim=1)[:, 1]
        else:
            # Single-logit binary format.
            probs = torch.sigmoid(logits)

        probs_np = probs.cpu().detach().numpy()
        labels_np = labels.cpu().detach().numpy()

        # Use isotonic regression as a robust approximation of the beta map.
        self.calibrator = IsotonicRegression(out_of_bounds="clip")
        self.calibrator.fit(probs_np, labels_np)
        return self

    def predict(self, logits: torch.Tensor) -> Categorical:
        """Get calibrated predictions."""
        # Multiclass fallback delegates to the isotonic calibrator.
        if not self.is_binary:
            return self.calibrator.predict(logits)

        # Binary classification: convert to positive-class probabilities.
        if logits.dim() == 2:
            probs = F.softmax(logits, dim=1)[:, 1]
        else:
            probs = torch.sigmoid(logits)

        probs_np = probs.cpu().detach().numpy()

        # Apply calibration.
        calibrated_probs = self.calibrator.predict(probs_np)
        calibrated_probs = torch.tensor(
            calibrated_probs, device=logits.device, dtype=torch.float32
        )

        # Stack negative- and positive-class probabilities.
        probs_both = torch.stack([1 - calibrated_probs, calibrated_probs], dim=1)
        return Categorical(probs=probs_both)

    def state_dict(self) -> dict:
        """Save Beta calibrator state."""
        import pickle

        return {
            "method": self.method,
            "a": self.a,
            "b": self.b,
            "map_params": self.map_params,
            "is_binary": getattr(self, "is_binary", None),
            "calibrator": pickle.dumps(getattr(self, "calibrator", None)),
        }

    def load_state_dict(self, state: dict) -> None:
        """Load Beta calibrator state."""
        import pickle

        from ..exceptions import SerializationError

        try:
            self.method = state["method"]
            self.a = state["a"]
            self.b = state["b"]
            self.map_params = state["map_params"]
            self.is_binary = state["is_binary"]
            self.calibrator = pickle.loads(state["calibrator"])
        except Exception as e:
            raise SerializationError(f"Failed to load state: {e}") from e

    def __repr__(self) -> str:
        return f"BetaCalibrator(method='{self.method}')"
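

# Hedged usage sketch (not part of the original module): binary calibration from
# single-logit inputs. The synthetic tensors and the helper name are illustrative
# assumptions only.
def _demo_beta_calibration() -> Categorical:
    torch.manual_seed(0)
    val_logits = torch.randn(256)                           # one logit per sample
    val_labels = (val_logits + 0.5 * torch.randn(256) > 0).long()
    bc = BetaCalibrator().fit(val_logits, val_labels)
    return bc.predict(torch.randn(32))                      # Categorical over {0, 1}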