"""
Selective-prediction algorithms and helper layers.
All methods expose a `forward(x, return_confidence=False)` signature
and a `.reject(confidence, threshold)` utility that returns a boolean
mask indicating which samples are *rejected* (i.e. deferred).
"""
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
from .base import BaseSelectivePredictor
def _infer_output_dim(backbone: nn.Module) -> int:
    """Infer the size of the backbone's last output dimension via a probe pass.

    A single all-zero row is pushed through ``backbone`` with gradients
    disabled.  The row width is taken from the last dimension of the
    backbone's first parameter (for a leading ``nn.Linear`` that is its
    ``in_features``); the probe is created on that parameter's device and
    with its dtype so half/double precision backbones can be probed too.

    The backbone is switched to eval mode for the probe.  The ``training``
    flag of every sub-module is recorded beforehand and restored afterwards,
    even if the forward pass raises, so sub-modules that were deliberately
    left in eval mode are not flipped back to training.

    Args:
        backbone: Module to probe.  It must own at least one parameter and
            accept an input of shape ``(1, first_param.shape[-1])``.

    Returns:
        ``out.shape[-1]`` of the probe output.

    Raises:
        ValueError: If ``backbone`` has no parameters, so no input width can
            be derived.
    """
    first_param = next(backbone.parameters(), None)
    if first_param is None:
        raise ValueError(
            "Cannot infer the output dimension of a backbone without "
            "parameters; pass `num_features` explicitly."
        )
    in_features = first_param.shape[-1]
    probe = torch.zeros(
        1, in_features, device=first_param.device, dtype=first_param.dtype
    )

    # Snapshot per-module modes: eval()/train() act recursively and would
    # otherwise overwrite any mixed train/eval configuration.
    saved_modes = [(module, module.training) for module in backbone.modules()]
    backbone.eval()
    try:
        with torch.no_grad():
            out = backbone(probe)
    finally:
        for module, was_training in saved_modes:
            module.training = was_training
    return out.shape[-1]
# ----------------------------------------------------------------------
# 1. Softmax-Threshold (MSP)
# ----------------------------------------------------------------------
class SoftmaxThreshold(BaseSelectivePredictor):
    """Confidence-thresholding baseline in the style of Chow (1957).

    The wrapped backbone's output is used directly as the class logits.
    Everything else (``forward``, confidence scoring, ``reject``) is
    inherited from ``BaseSelectivePredictor``, which lives in ``.base``.
    """

    def __init__(self, backbone: nn.Module):
        """Store ``backbone`` as the module that produces the class logits."""
        super().__init__()
        self.backbone = backbone

    def _forward_logits(self, x: torch.Tensor) -> torch.Tensor:  # noqa: D401
        """Class logits for ``x``: the raw backbone output, unmodified."""
        logits = self.backbone(x)
        return logits
# ----------------------------------------------------------------------
# 2. Deep Gambler
# ----------------------------------------------------------------------
class DeepGambler(BaseSelectivePredictor):
    """
    Add an extra *abstain* logit and train with the gambler's loss:

        L = −log( p_y + r / o )

    where ``o`` is the reward for a correct prediction, ``p_y`` is the
    probability assigned to the true class, and ``r`` = P(abstain).

    Reference:
        Ziyin et al., "Deep Gamblers: Learning to Abstain with Portfolio
        Theory", NeurIPS 2019.
    """

    def __init__(
        self, backbone: nn.Module, num_classes: int, num_features: int | None = None
    ):
        """
        Args:
            backbone: Feature extractor whose output feeds the linear head.
            num_classes: Number of real classes ``C``; the head emits
                ``C + 1`` logits, the last one being *abstain*.
            num_features: Width of the backbone output.  When ``None`` it is
                discovered with a probe forward pass through ``backbone``.
        """
        super().__init__()
        self.backbone = backbone
        # small linear head that outputs C + 1 logits (extra abstain)
        last_dim = (
            num_features if num_features is not None else _infer_output_dim(backbone)
        )
        self.head = nn.Linear(last_dim, num_classes + 1)
        # Initialise the abstain logit bias to a negative value so the model
        # starts by predicting classes rather than collapsing to always-abstain.
        with torch.no_grad():
            self.head.bias[-1] = -3.0

    def _forward_logits(self, x: torch.Tensor) -> torch.Tensor:
        """Return the ``C + 1`` head logits computed on the backbone features."""
        feats = self.backbone(x)
        return self.head(feats)

    def confidence_from_logits(self, logits: torch.Tensor) -> torch.Tensor:
        """Return ``1 − P(abstain)``, where abstain is the last logit."""
        probs = F.softmax(logits, dim=-1)
        return 1.0 - probs[..., -1]

    @staticmethod
    def gambler_loss(
        logits: torch.Tensor,
        targets: torch.Tensor,
        reward: float = 2.2,
    ) -> torch.Tensor:
        """
        Gambler's loss from Ziyin et al. (NeurIPS 2019).

        The loss encourages the model to either predict correctly or abstain:

            L = −log( p_y + r / o )

        where ``o`` = reward, ``p_y`` = P(true class), ``r`` = P(abstain).
        Higher ``reward`` penalises abstention more, pushing towards prediction.

        The value is evaluated in log space,
        ``−logaddexp(log p_y, log r − log o)``, so it stays finite and keeps
        a useful gradient even when the softmax probabilities would
        underflow; no additive epsilon is needed.

        Args:
            logits: Model output of shape (..., num_classes + 1).
                Last column is the abstain logit.
            targets: Ground-truth class labels of shape (...,), matching the
                leading dimensions of ``logits``.  Must index a real class
                (``< num_classes``), never the abstain column.
            reward: Reward for correct prediction (called *o* in the paper).
                Must be > 1. Higher values discourage abstention.

        Returns:
            Scalar loss (mean over all samples).

        Raises:
            ValueError: If ``reward`` is not strictly positive, for which
                ``r / o`` has no valid logarithm.
        """
        if reward <= 0:
            raise ValueError(f"reward must be positive, got {reward!r}")

        log_probs = F.log_softmax(logits, dim=-1)
        # log p_y: gather from the class columns only, so a target equal to
        # the abstain index is rejected instead of being silently accepted.
        log_p_target = (
            log_probs[..., :-1].gather(-1, targets.unsqueeze(-1)).squeeze(-1)
        )
        # log(r / o) = log r − log o
        log_reserve_share = log_probs[..., -1] - log_probs.new_tensor(reward).log()
        # −log(p_y + r / o), computed stably
        loss = -torch.logaddexp(log_p_target, log_reserve_share)
        return loss.mean()
# ----------------------------------------------------------------------
# 3. SelectiveNet
# ----------------------------------------------------------------------
class SelectiveNet(BaseSelectivePredictor):
    """
    SelectiveNet (Geifman & El-Yaniv, 2019).

    Two heads share the backbone features:

    * ``h(x)`` — a linear layer producing class logits
    * ``g(x)`` — a small MLP ending in a sigmoid, producing the selection
      probability that serves as the confidence score
    """

    def __init__(
        self,
        backbone: nn.Module,
        num_classes: int,
        hidden: int = 128,
        alpha: float = 0.5,
        lam: float = 32.0,
        num_features: int | None = None,
    ):
        """
        Args:
            backbone: Feature extractor shared by both heads.
            num_classes: Number of outputs of the prediction head ``h``.
            hidden: Width of the hidden layer inside the selection head ``g``.
            alpha: Coverage target used by :meth:`selective_loss`.
            lam: Coefficient of the coverage penalty in :meth:`selective_loss`.
            num_features: Width of the backbone output.  When ``None`` it is
                discovered with a probe forward pass through ``backbone``.
        """
        super().__init__()
        self.alpha = alpha  # coverage target in loss
        self.lam = lam  # penalty coefficient for coverage constraint
        self.backbone = backbone

        if num_features is None:
            feat_dim = _infer_output_dim(backbone)
        else:
            feat_dim = num_features

        # Prediction head first, selection head second.
        self.h = nn.Linear(feat_dim, num_classes)
        selection_layers = [
            nn.Linear(feat_dim, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),
        ]
        self.g = nn.Sequential(*selection_layers)

    def _forward_logits(self, x: torch.Tensor) -> torch.Tensor:
        """Class logits only: ``h(backbone(x))``."""
        return self.h(self.backbone(x))

    def forward(  # type: ignore[override]
        self,
        x: torch.Tensor,
        *,
        return_confidence: bool = False,
    ):
        """Run both heads on the shared features.

        Returns the logits alone, or ``(logits, selection_probability)`` when
        ``return_confidence`` is true.  The selection head is evaluated in
        either case.
        """
        shared = self.backbone(x)
        class_logits = self.h(shared)
        # Drop the trailing singleton dim so confidence lies in [0, 1] per sample.
        confidence = self.g(shared).squeeze(-1)
        return (class_logits, confidence) if return_confidence else class_logits

    def confidence_from_logits(self, logits):
        """Unsupported: confidence comes from the selection head, not the logits.

        Use ``forward(x, return_confidence=True)`` instead.
        """
        raise NotImplementedError(
            "SelectiveNet uses a dedicated selection head (g). "
            "Call forward(x, return_confidence=True) to get confidence."
        )

    def selective_loss(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor,
        selection: torch.Tensor,
        coverage_target: float | None = None,
    ) -> torch.Tensor:
        """
        SelectiveNet loss from Geifman & El-Yaniv (ICML 2019).

        Sum of the selective risk and a quadratic coverage penalty:

            L = L_selective + λ * max(0, c − Φ)²

        ``L_selective`` is the selection-weighted mean cross-entropy divided
        by the empirical coverage ``Φ`` (the mean of ``selection``), ``c`` is
        the coverage target and ``λ`` is ``self.lam``.

        Args:
            logits: Class logits of shape (batch, num_classes).
            targets: Ground-truth labels of shape (batch,).
            selection: Selection probabilities g(x) of shape (batch,)
                in [0, 1], as returned by
                ``forward(x, return_confidence=True)``.
            coverage_target: Desired coverage.  Falls back to ``self.alpha``
                when ``None``.

        Returns:
            Scalar loss.
        """
        target_cov = self.alpha if coverage_target is None else coverage_target

        coverage = selection.mean()
        ce = F.cross_entropy(logits, targets, reduction="none")
        # Epsilon guards the division when nothing is selected.
        risk = (ce * selection).mean() / (coverage + 1e-9)

        # Only a coverage shortfall is penalised.
        shortfall = torch.clamp(target_cov - coverage, min=0.0)
        penalty = self.lam * shortfall ** 2
        return risk + penalty
# ----------------------------------------------------------------------
# 4. Self-Adaptive Training (SAT)
# ----------------------------------------------------------------------
class SelfAdaptiveTraining(BaseSelectivePredictor):
    """
    Self-Adaptive Training for better calibration and selective prediction.

    Trains with adaptive soft labels that blend ground truth and model predictions:

        y_adaptive = (1 - alpha) * y_hard + alpha * softmax(logits)

    This improves calibration naturally during training, making the model better
    at knowing when to reject/abstain on uncertain samples.

    Reference:
        Huang et al., "Self-Adaptive Training: beyond Empirical Risk Minimization"
        NeurIPS 2020.
    """

    def __init__(
        self,
        backbone: nn.Module,
        num_classes: int,
        alpha_start: float = 0.0,
        alpha_end: float = 0.9,
        warmup_epochs: int = 5,
    ):
        """
        Args:
            backbone: The base model to train
            num_classes: Number of classes
            alpha_start: Initial alpha value (0 = pure hard labels)
            alpha_end: Final alpha value (higher = more self-supervision)
            warmup_epochs: Number of epochs before starting SAT
        """
        super().__init__()
        self.backbone = backbone
        self.num_classes = num_classes
        self.alpha_start = alpha_start
        self.alpha_end = alpha_end
        self.warmup_epochs = warmup_epochs
        self.current_epoch = 0

    def _forward_logits(self, x: torch.Tensor) -> torch.Tensor:
        """Return the backbone output as class logits."""
        return self.backbone(x)

    def get_alpha(self, epoch: int, total_epochs: int) -> float:
        """
        Compute current alpha value based on training progress.

        During warm-up (``epoch < warmup_epochs``) this is ``alpha_start``.
        Afterwards alpha moves linearly from ``alpha_start`` to ``alpha_end``
        over the remaining ``total_epochs - warmup_epochs`` epochs and stays
        at ``alpha_end`` from then on.  The schedule may be increasing or
        decreasing: the progress fraction is clamped, not the alpha value.

        Args:
            epoch: Current epoch number
            total_epochs: Total number of training epochs

        Returns:
            Alpha value for blending hard and soft labels
        """
        if epoch < self.warmup_epochs:
            return self.alpha_start

        # Guard against a zero/negative ramp length when warm-up covers
        # (or exceeds) the whole run.
        ramp_epochs = max(total_epochs - self.warmup_epochs, 1)
        progress = (epoch - self.warmup_epochs) / ramp_epochs
        if progress >= 1.0:
            # Return the endpoint itself to avoid floating-point overshoot.
            return self.alpha_end
        return self.alpha_start + (self.alpha_end - self.alpha_start) * progress

    def sat_loss(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor,
        alpha: float,
    ) -> torch.Tensor:
        """
        Compute Self-Adaptive Training loss.

        With ``alpha == 0`` this is plain cross-entropy.  Otherwise it is the
        cross-entropy against ``(1 - alpha) * one_hot + alpha * softmax(logits)``,
        where the softmax term is detached so it acts as a fixed target.

        Args:
            logits: Model predictions (batch_size, num_classes)
            targets: Ground truth labels (batch_size,)
            alpha: Mixing coefficient for soft labels

        Returns:
            SAT loss value (mean over the batch)
        """
        # Standard cross-entropy with hard labels
        if alpha == 0.0:
            return F.cross_entropy(logits, targets)

        # Create one-hot encoding of hard labels
        hard_labels = F.one_hot(targets, num_classes=self.num_classes).float()

        # Soft labels from the current predictions; detached so no gradient
        # flows through the target side of the loss.
        with torch.no_grad():
            soft_labels = F.softmax(logits.detach(), dim=1)

        # Blend hard and soft labels
        adaptive_labels = (1 - alpha) * hard_labels + alpha * soft_labels

        # Compute cross-entropy with adaptive labels
        log_probs = F.log_softmax(logits, dim=1)
        loss = -(adaptive_labels * log_probs).sum(dim=1).mean()
        return loss
# ----------------------------------------------------------------------
# FACTORY UTIL
# ----------------------------------------------------------------------
def make(selector: str, *args, **kwargs) -> BaseSelectivePredictor:
    """Build a selective predictor from a short name.

    ``selector`` is matched case-insensitively against these aliases:

    * ``msp`` / ``softmax`` / ``threshold`` -> ``SoftmaxThreshold``
    * ``selectivenet`` / ``sn`` -> ``SelectiveNet``
    * ``gambler`` / ``deepgambler`` -> ``DeepGambler``
    * ``sat`` / ``selfadaptive`` / ``self-adaptive`` -> ``SelfAdaptiveTraining``

    All remaining positional and keyword arguments are forwarded to the
    chosen constructor, e.g. ``make('msp', backbone)``.

    Raises:
        ValueError: If the (lower-cased) name matches no alias.
    """
    selector = selector.lower()

    if selector in ("msp", "softmax", "threshold"):
        predictor_cls = SoftmaxThreshold
    elif selector in ("selectivenet", "sn"):
        predictor_cls = SelectiveNet
    elif selector in ("gambler", "deepgambler"):
        predictor_cls = DeepGambler
    elif selector in ("sat", "selfadaptive", "self-adaptive"):
        predictor_cls = SelfAdaptiveTraining
    else:
        raise ValueError(f"Unknown selector {selector!r}")

    return predictor_cls(*args, **kwargs)