Source code for incerto.calibration.training

"""
Training-time calibration methods.

Methods that improve calibration during training, rather than
post-hoc after training is complete.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class LabelSmoothingLoss(nn.Module):
    """
    Label Smoothing for improved calibration.

    Softens hard labels to prevent overconfidence:

        y_smooth = (1 - α) * y_hard + α / K

    where α is the smoothing parameter and K is the number of classes.

    Reference:
        Szegedy et al. "Rethinking the Inception Architecture" (CVPR 2016)
        Müller et al. "When Does Label Smoothing Help?" (NeurIPS 2019)

    Args:
        smoothing: Smoothing parameter (default: 0.1)
        reduction: Reduction method ('mean', 'sum', or 'none')

    Example:
        >>> criterion = LabelSmoothingLoss(smoothing=0.1)
        >>> logits = torch.randn(32, 10)
        >>> targets = torch.randint(0, 10, (32,))
        >>> loss = criterion(logits, targets)
    """

    def __init__(self, smoothing: float = 0.1, reduction: str = "mean"):
        super().__init__()
        self.smoothing = smoothing
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute label smoothing loss.

        Args:
            logits: Model outputs (N, C)
            targets: Ground truth labels (N,)

        Returns:
            Loss value
        """
        num_classes = logits.size(-1)
        log_probs = F.log_softmax(logits, dim=-1)

        # Create smooth labels
        targets_one_hot = F.one_hot(targets, num_classes).float()
        smooth_targets = (
            1 - self.smoothing
        ) * targets_one_hot + self.smoothing / num_classes

        # Compute loss
        loss = -(smooth_targets * log_probs).sum(dim=-1)

        if self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        else:
            return loss
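

# Illustrative sketch (not part of the original module): a quick self-check that
# the loss above matches PyTorch's built-in label smoothing, assuming
# PyTorch >= 1.10, where F.cross_entropy accepts a `label_smoothing` argument.
def _example_label_smoothing_matches_builtin() -> None:
    """Illustrative sketch only; not part of the incerto API."""
    criterion = LabelSmoothingLoss(smoothing=0.1)
    logits = torch.randn(8, 5)
    targets = torch.randint(0, 5, (8,))
    # Both compute -sum(((1 - α) * one_hot + α / K) * log_softmax(logits)), averaged.
    assert torch.allclose(
        criterion(logits, targets),
        F.cross_entropy(logits, targets, label_smoothing=0.1),
        atol=1e-5,
    )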


class FocalLoss(nn.Module):
    """
    Focal Loss for handling hard examples.

    Down-weights easy examples to focus on hard ones:

        FL(p_t) = -α_t * (1 - p_t)^γ * log(p_t)

    The (1 - p_t)^γ modulating factor reduces the loss contribution from
    easy examples and extends the range where examples receive low loss.

    Reference:
        Lin et al. "Focal Loss for Dense Object Detection" (ICCV 2017)

    Args:
        alpha: Weighting factor (default: 1.0)
        gamma: Focusing parameter (default: 2.0)
        reduction: Reduction method ('mean', 'sum', or 'none')

    Example:
        >>> criterion = FocalLoss(gamma=2.0)
        >>> logits = torch.randn(32, 10)
        >>> targets = torch.randint(0, 10, (32,))
        >>> loss = criterion(logits, targets)
    """

    def __init__(self, alpha: float = 1.0, gamma: float = 2.0, reduction: str = "mean"):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute focal loss.

        Args:
            logits: Model outputs (N, C)
            targets: Ground truth labels (N,)

        Returns:
            Loss value
        """
        ce_loss = F.cross_entropy(logits, targets, reduction="none")
        p_t = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - p_t) ** self.gamma * ce_loss

        if self.reduction == "mean":
            return focal_loss.mean()
        elif self.reduction == "sum":
            return focal_loss.sum()
        else:
            return focal_loss
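

# Illustrative sketch (not part of the original module): with gamma = 0 and
# alpha = 1 the modulating factor (1 - p_t)^γ equals 1, so the focal loss
# should reduce to plain cross-entropy.
def _example_focal_loss_reduces_to_cross_entropy() -> None:
    """Illustrative sketch only; not part of the incerto API."""
    logits = torch.randn(8, 5)
    targets = torch.randint(0, 5, (8,))
    focal = FocalLoss(alpha=1.0, gamma=0.0)(logits, targets)
    ce = F.cross_entropy(logits, targets)
    assert torch.allclose(focal, ce, atol=1e-6)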


class ConfidencePenalty(nn.Module):
    """
    Confidence Penalty to prevent overconfidence.

    Regularizes model predictions to have higher entropy:

        Loss = CE + β * (-H(p))

    where H(p) is the entropy of predictions. The negative entropy term
    penalizes confident predictions.

    Reference:
        Pereyra et al. "Regularizing Neural Networks by Penalizing
        Confident Output Distributions" (ICLR 2017)

    Args:
        beta: Penalty weight (default: 0.1)

    Example:
        >>> criterion = ConfidencePenalty(beta=0.1)
        >>> logits = torch.randn(32, 10)
        >>> targets = torch.randint(0, 10, (32,))
        >>> loss = criterion(logits, targets)
    """

    def __init__(self, beta: float = 0.1):
        super().__init__()
        self.beta = beta

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute confidence penalty loss.

        Args:
            logits: Model outputs (N, C)
            targets: Ground truth labels (N,)

        Returns:
            Loss value
        """
        # Standard cross-entropy
        ce_loss = F.cross_entropy(logits, targets)

        # Confidence penalty (negative entropy)
        probs = F.softmax(logits, dim=-1)
        log_probs = F.log_softmax(logits, dim=-1)
        entropy = -(probs * log_probs).sum(dim=-1).mean()
        confidence_penalty = -entropy

        total_loss = ce_loss + self.beta * confidence_penalty

        return total_loss


def evidential_loss(
    evidence: torch.Tensor,
    targets: torch.Tensor,
    num_classes: int,
    epoch: int,
    num_epochs: int,
    kl_weight: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Evidential Deep Learning loss.

    Learns Dirichlet distributions over class probabilities, enabling
    second-order uncertainty estimation.

        Loss = MSE(p, y) + λ * KL[Dir(α_tilde) || Dir(1)]

    where α = evidence + 1 (Dirichlet parameters) and λ is annealed from
    0 to kl_weight over the first half of training.

    Reference:
        Sensoy et al. "Evidential Deep Learning to Quantify Classification
        Uncertainty" (NeurIPS 2018)

    Args:
        evidence: Non-negative evidence values (N, C)
        targets: Ground truth labels (N,)
        num_classes: Number of classes
        epoch: Current epoch
        num_epochs: Total number of epochs
        kl_weight: Maximum KL weight (default: 1.0)

    Returns:
        tuple of (total_loss, mse_loss, kl_loss)

    Example:
        >>> evidence = F.softplus(model(x))  # Ensure non-negative
        >>> loss, mse, kl = evidential_loss(evidence, targets, 10, epoch, total_epochs)
    """
    # Dirichlet parameters
    alpha = evidence + 1
    S = alpha.sum(dim=1, keepdim=True)

    # One-hot encode targets
    targets_one_hot = F.one_hot(targets, num_classes).float()

    # MSE loss between predicted probabilities and true labels
    prob = alpha / S
    mse_loss = ((targets_one_hot - prob) ** 2).sum(dim=1).mean()

    # KL divergence regularization
    # Encourages high uncertainty on wrong predictions
    alpha_tilde = targets_one_hot + (1 - targets_one_hot) * alpha
    S_tilde = alpha_tilde.sum(dim=1, keepdim=True)

    # KL[Dir(alpha_tilde) || Dir(1)] (the constant -lgamma(num_classes) term of
    # the uniform Dirichlet is omitted; it does not affect gradients)
    first_term = torch.lgamma(S_tilde) - torch.lgamma(alpha_tilde).sum(
        dim=1, keepdim=True
    )
    second_term = (
        (alpha_tilde - 1) * (torch.digamma(alpha_tilde) - torch.digamma(S_tilde))
    ).sum(dim=1, keepdim=True)
    kl_loss = (first_term + second_term).mean()

    # Anneal KL coefficient from 0 to kl_weight over the first half of training
    kl_coeff = min(kl_weight, epoch / (num_epochs * 0.5))

    total_loss = mse_loss + kl_coeff * kl_loss

    return total_loss, mse_loss, kl_loss
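

# Illustrative training sketch (not part of the original module): `model`,
# `loader`, and `optimizer` are hypothetical placeholders; the softplus
# evidence head follows the docstring example above.
def _example_evidential_training(model, loader, optimizer, num_epochs: int) -> None:
    """Illustrative sketch only; not part of the incerto API."""
    for epoch in range(num_epochs):
        for x, y in loader:
            evidence = F.softplus(model(x))  # non-negative evidence per class
            loss, mse, kl = evidential_loss(
                evidence,
                y,
                num_classes=evidence.size(1),
                epoch=epoch,
                num_epochs=num_epochs,
            )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()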


class TemperatureAwareTraining(nn.Module):
    """
    Temperature-aware training with learnable temperature.

    Instead of post-hoc temperature scaling, learns the temperature
    parameter during training for better calibration.

    Args:
        backbone: Base neural network
        init_temp: Initial temperature (default: 1.0)
        learn_temp: Whether to learn temperature (default: True)

    Example:
        >>> backbone = ResNet18(num_classes=10)
        >>> model = TemperatureAwareTraining(backbone)
        >>> logits = model(x)
        >>> loss = F.cross_entropy(logits, targets)
    """

    def __init__(
        self,
        backbone: nn.Module,
        init_temp: float = 1.0,
        learn_temp: bool = True,
    ):
        super().__init__()
        self.backbone = backbone
        self.temperature = nn.Parameter(
            torch.ones(1) * init_temp,
            requires_grad=learn_temp,
        )

    def forward(self, x: torch.Tensor, return_unscaled: bool = False) -> torch.Tensor:
        """
        Forward pass with temperature scaling.

        Args:
            x: Input tensor
            return_unscaled: If True, return unscaled logits

        Returns:
            Temperature-scaled logits
        """
        logits = self.backbone(x)

        if return_unscaled:
            return logits

        # Apply temperature scaling
        scaled_logits = logits / self.temperature

        return scaled_logits


def get_uncertainty_from_evidence(evidence: torch.Tensor, num_classes: int) -> dict:
    """
    Compute uncertainty measures from evidential outputs.

    Args:
        evidence: Non-negative evidence values (N, C)
        num_classes: Number of classes

    Returns:
        Dictionary with:
            - alpha: Dirichlet parameters (N, C)
            - belief: Predicted class probabilities (Dirichlet mean) (N, C)
            - uncertainty: Total uncertainty / vacuity (N, 1)
            - epistemic: Epistemic uncertainty (N, 1)

    Example:
        >>> evidence = F.softplus(model(x))
        >>> uncertainty = get_uncertainty_from_evidence(evidence, num_classes=10)
        >>> total_unc = uncertainty['uncertainty']
    """
    alpha = evidence + 1
    S = alpha.sum(dim=1, keepdim=True)

    # Predicted class probabilities (Dirichlet mean, alpha / S)
    belief = alpha / S

    # Total uncertainty (vacuity)
    uncertainty = num_classes / S

    # Epistemic uncertainty (model uncertainty): sum of per-class Dirichlet variances
    epistemic = (alpha * (S - alpha)) / (S * S * (S + 1))
    epistemic = epistemic.sum(dim=1, keepdim=True)

    return {
        "alpha": alpha,
        "belief": belief,
        "uncertainty": uncertainty,
        "epistemic": epistemic,
    }


__all__ = [
    "LabelSmoothingLoss",
    "FocalLoss",
    "ConfidencePenalty",
    "evidential_loss",
    "get_uncertainty_from_evidence",
    "TemperatureAwareTraining",
]
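

# Illustrative inference sketch (not part of the original module): `model` and
# `x` are hypothetical placeholders, and the 0.5 vacuity threshold is an
# arbitrary choice for the example.
def _example_evidential_inference(model, x: torch.Tensor) -> dict:
    """Illustrative sketch only; not part of the incerto API."""
    with torch.no_grad():
        evidence = F.softplus(model(x))
        out = get_uncertainty_from_evidence(evidence, num_classes=evidence.size(1))
    preds = out["belief"].argmax(dim=1)  # most likely class per sample
    uncertain = out["uncertainty"].squeeze(1) > 0.5  # flag near-vacuous predictions
    return {"preds": preds, "uncertain": uncertain, **out}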