Source code for incerto.sp.methods

"""
Selective-prediction algorithms and helper layers.

Every method exposes a `forward(x, return_confidence=False)` entry point
and a `.reject(confidence, threshold)` helper that returns a boolean mask
marking which samples are *rejected* (i.e. deferred).
"""

from __future__ import annotations
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .base import BaseSelectivePredictor


# ----------------------------------------------------------------------
#                         1. Softmax-Threshold (MSP)
# ----------------------------------------------------------------------
class SoftmaxThreshold(BaseSelectivePredictor):
    """Classical confidence-thresholding à la Chow (1957)."""

    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone

    def _forward_logits(self, x: torch.Tensor) -> torch.Tensor:  # noqa: D401
        return self.backbone(x)
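
# Usage sketch (illustrative only, not part of the module): it assumes the
# base class returns `(logits, confidence)` when `return_confidence=True`
# and provides the `.reject(confidence, threshold)` helper described in the
# module docstring; `backbone` is a hypothetical logit-producing model.
#
#     backbone = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
#     msp = SoftmaxThreshold(backbone)
#     logits, conf = msp(torch.randn(32, 1, 28, 28), return_confidence=True)
#     deferred = msp.reject(conf, 0.7)          # True where the sample is deferred
#     preds = logits.argmax(dim=-1)[~deferred]  # predictions on accepted samples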


# ----------------------------------------------------------------------
#                         2. Deep Gambler
# ----------------------------------------------------------------------
class DeepGambler(BaseSelectivePredictor):
    """
    Add an extra *abstain* logit and train with the gambler's loss:

        L = −log( (1 − r) * p_y + r / C )

    where `r` is the reserve (confidence to abstain) and `C` is the
    number of classes.
    """

    def __init__(self, backbone: nn.Module, num_classes: int):
        super().__init__()
        self.backbone = backbone
        # small linear head that outputs C + 1 logits (extra abstain)
        last_dim = list(backbone.parameters())[-1].shape[0]
        self.head = nn.Linear(last_dim, num_classes + 1)

    def _forward_logits(self, x: torch.Tensor) -> torch.Tensor:
        feats = self.backbone(x)
        return self.head(feats)

    def confidence_from_logits(self, logits: torch.Tensor) -> torch.Tensor:
        *class_logits, abstain_logit = logits.split_with_sizes(
            [logits.size(-1) - 1, 1], dim=-1
        )
        class_logits = torch.cat(class_logits, dim=-1)
        # confidence is 1 − probability of abstain
        probs = F.softmax(torch.cat([class_logits, abstain_logit], dim=-1), dim=-1)
        return 1.0 - probs[..., -1]
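
# Sketch of the gambler's loss described in the docstring above
# (illustrative only; the library's actual training loss may differ).
# It assumes the C + 1 logits produced by `DeepGambler` and reads `p_y`
# as the class probability conditioned on not abstaining, so that
# (1 − r) * p_y equals the joint softmax probability of class y.
#
#     def gambler_loss(logits, targets, eps=1e-12):
#         probs = F.softmax(logits, dim=-1)   # (B, C + 1), last entry = abstain
#         r = probs[..., -1]                  # reserve (abstain) probability
#         joint_py = probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
#         num_classes = probs.size(-1) - 1
#         return -torch.log(joint_py + r / num_classes + eps).mean()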


# ----------------------------------------------------------------------
#                         3. SelectiveNet
# ----------------------------------------------------------------------
class SelectiveNet(BaseSelectivePredictor):
    """
    Implementation of SelectiveNet (Geifman & El-Yaniv, 2019).

    The model outputs:

    * h(x): class logits
    * g(x): selection probability
    """

    def __init__(
        self,
        backbone: nn.Module,
        num_classes: int,
        hidden: int = 128,
        alpha: float = 0.5,
    ):
        super().__init__()
        self.backbone = backbone
        last_dim = list(backbone.parameters())[-1].shape[0]
        self.h = nn.Linear(last_dim, num_classes)
        self.g = nn.Sequential(
            nn.Linear(last_dim, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),
        )
        self.alpha = alpha  # coverage target in loss

    def _forward_logits(self, x: torch.Tensor) -> torch.Tensor:
        feats = self.backbone(x)
        return self.h(feats)

    def forward(  # type: ignore[override]
        self,
        x: torch.Tensor,
        *,
        return_confidence: bool = False,
    ):
        feats = self.backbone(x)
        logits = self.h(feats)
        sel_prob = self.g(feats).squeeze(-1)  # confidence ∈ [0, 1]
        if return_confidence:
            return logits, sel_prob
        return logits
    def confidence_from_logits(self, logits):  # unused: `forward` is overridden
        raise NotImplementedError
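
# Usage sketch (illustrative only): `backbone` is a hypothetical feature
# extractor whose final layer has 512 output features, so the `last_dim`
# heuristic in `__init__` resolves to 512.
#
#     backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 512), nn.ReLU())
#     model = SelectiveNet(backbone, num_classes=10)
#     logits, sel_prob = model(torch.randn(8, 3, 32, 32), return_confidence=True)
#     accepted = sel_prob >= 0.5           # predict only where g(x) is confident
#     preds = logits.argmax(dim=-1)[accepted]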


# ----------------------------------------------------------------------
#                  4. Self-Adaptive Training (SAT)
# ----------------------------------------------------------------------
class SelfAdaptiveTraining(BaseSelectivePredictor):
    """
    Self-Adaptive Training for better calibration and selective prediction.

    Trains with adaptive soft labels that blend ground truth and model
    predictions:

        y_adaptive = (1 - alpha) * y_hard + alpha * softmax(logits)

    This improves calibration naturally during training, making the model
    better at knowing when to reject/abstain on uncertain samples.

    Reference: Huang et al., "Self-Adaptive Training: beyond Empirical
    Risk Minimization", NeurIPS 2020.
    """

    def __init__(
        self,
        backbone: nn.Module,
        num_classes: int,
        alpha_start: float = 0.0,
        alpha_end: float = 0.9,
        warmup_epochs: int = 5,
    ):
        """
        Args:
            backbone: The base model to train
            num_classes: Number of classes
            alpha_start: Initial alpha value (0 = pure hard labels)
            alpha_end: Final alpha value (higher = more self-supervision)
            warmup_epochs: Number of epochs before starting SAT
        """
        super().__init__()
        self.backbone = backbone
        self.num_classes = num_classes
        self.alpha_start = alpha_start
        self.alpha_end = alpha_end
        self.warmup_epochs = warmup_epochs
        self.current_epoch = 0

    def _forward_logits(self, x: torch.Tensor) -> torch.Tensor:
        return self.backbone(x)

    def get_alpha(self, epoch: int, total_epochs: int) -> float:
        """
        Compute the current alpha value based on training progress.

        Args:
            epoch: Current epoch number
            total_epochs: Total number of training epochs

        Returns:
            Alpha value for blending hard and soft labels
        """
        if epoch < self.warmup_epochs:
            return self.alpha_start
        # Linear schedule from alpha_start to alpha_end
        progress = (epoch - self.warmup_epochs) / max(
            total_epochs - self.warmup_epochs, 1
        )
        alpha = self.alpha_start + (self.alpha_end - self.alpha_start) * progress
        return min(alpha, self.alpha_end)

    def sat_loss(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor,
        alpha: float,
    ) -> torch.Tensor:
        """
        Compute the Self-Adaptive Training loss.

        Args:
            logits: Model predictions (batch_size, num_classes)
            targets: Ground-truth labels (batch_size,)
            alpha: Mixing coefficient for soft labels

        Returns:
            SAT loss value
        """
        # Standard cross-entropy with hard labels
        if alpha == 0.0:
            return F.cross_entropy(logits, targets)

        # One-hot encoding of the hard labels
        hard_labels = F.one_hot(targets, num_classes=self.num_classes).float()

        # Soft labels from the current model predictions (detached to avoid gradient flow)
        with torch.no_grad():
            soft_labels = F.softmax(logits.detach(), dim=1)

        # Blend hard and soft labels
        adaptive_labels = (1 - alpha) * hard_labels + alpha * soft_labels

        # Cross-entropy with the adaptive labels
        log_probs = F.log_softmax(logits, dim=1)
        loss = -(adaptive_labels * log_probs).sum(dim=1).mean()
        return loss
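
# Training-step sketch (illustrative only; `train_loader`, `total_epochs`
# and the optimizer are placeholders, and `sat(x)` is assumed to return
# plain logits when `return_confidence` is left False, as
# `SelectiveNet.forward` does). With the defaults (alpha_start=0.0,
# alpha_end=0.9, warmup_epochs=5), `get_alpha(10, 20)` = 0.3 and
# `get_alpha(20, 20)` = 0.9.
#
#     sat = SelfAdaptiveTraining(backbone, num_classes=10)
#     optimiser = torch.optim.SGD(sat.parameters(), lr=0.1)
#     for epoch in range(total_epochs):
#         alpha = sat.get_alpha(epoch, total_epochs)
#         for x, y in train_loader:
#             loss = sat.sat_loss(sat(x), y, alpha)
#             optimiser.zero_grad()
#             loss.backward()
#             optimiser.step()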


# ----------------------------------------------------------------------
#                              FACTORY UTIL
# ----------------------------------------------------------------------
def make(selector: str, *args, **kwargs) -> BaseSelectivePredictor:
    """Quick factory: `make('msp', backbone)` or `make('selectivenet', ...)`."""
    selector = selector.lower()
    if selector in {"msp", "softmax", "threshold"}:
        return SoftmaxThreshold(*args, **kwargs)
    if selector in {"selectivenet", "sn"}:
        return SelectiveNet(*args, **kwargs)
    if selector in {"gambler", "deepgambler"}:
        return DeepGambler(*args, **kwargs)
    if selector in {"sat", "selfadaptive", "self-adaptive"}:
        return SelfAdaptiveTraining(*args, **kwargs)
    raise ValueError(f"Unknown selector {selector!r}")
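
# Factory usage sketch (illustrative only; `backbone` is any model matching
# the constructors above):
#
#     predictor = make("msp", backbone)                         # SoftmaxThreshold
#     gambler = make("deepgambler", backbone, num_classes=10)   # DeepGambler
#     sat = make("self-adaptive", backbone, num_classes=10)     # SelfAdaptiveTraining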