Source code for incerto.sp.base
"""
Base classes for selective prediction methods.
Selective prediction (also called prediction with rejection) allows models
to abstain from making predictions when uncertain, trading coverage for accuracy.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
[docs]
class BaseSelectivePredictor(nn.Module, ABC):
    """
    Abstract base class for selective prediction methods.

    Selective predictors enable models to abstain from predictions when
    uncertain, optimizing the risk-coverage tradeoff. This is crucial for
    safety-critical applications where wrong predictions are costly.

    All selective predictors:

    1. Generate logits via ``_forward_logits()``
    2. Compute confidence scores (default: max softmax probability)
    3. Reject low-confidence predictions

    Subclasses must implement ``_forward_logits()`` which defines how
    predictions are made (may include rejection mechanism in architecture).

    Example:
        >>> class MySelectiveModel(BaseSelectivePredictor):
        ...     def __init__(self, backbone):
        ...         super().__init__()
        ...         self.backbone = backbone
        ...
        ...     def _forward_logits(self, x):
        ...         return self.backbone(x)
        ...
        >>> model = MySelectiveModel(resnet)
        >>> logits, conf = model(x, return_confidence=True)
        >>> should_reject = model.reject(conf, threshold=0.9)
        >>> # Make predictions only on high-confidence samples
        >>> predictions = logits[~should_reject].argmax(dim=-1)

    See Also:
        - SoftmaxThreshold: Simple confidence thresholding
        - SelfAdaptiveTraining: SAT with learned rejection
        - DeepGambler: Gambler's loss for selective classification
        - SelectiveNet: Auxiliary selection head
    """

    def forward(  # type: ignore[override]
        self,
        x: torch.Tensor,
        *,
        return_confidence: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor] | torch.Tensor:
        """
        Forward pass with optional confidence scores.

        Args:
            x: Input tensor of shape (batch_size, *input_dims)
            return_confidence: If True, return (logits, confidence) tuple.
                If False, return only logits.

        Returns:
            If return_confidence=False: logits tensor
            If return_confidence=True: (logits, confidence) tuple

        Example:
            >>> logits = model(x)  # Just predictions
            >>> logits, conf = model(x, return_confidence=True)  # With confidence
        """
        logits = self._forward_logits(x)
        if return_confidence:
            # Dispatch through ``self`` so a subclass override of
            # ``confidence_from_logits`` is the one that gets used.
            confidence = self.confidence_from_logits(logits)
            return logits, confidence
        return logits

    @abstractmethod
    def _forward_logits(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute logits for input samples.

        This is the core method that subclasses must implement.

        Args:
            x: Input tensor of shape (batch_size, *input_dims)

        Returns:
            Logits tensor of shape (batch_size, num_classes)
        """
        ...

    # ------------------------------------------------------------------
    # UTILITIES
    # ------------------------------------------------------------------
    @staticmethod
    def confidence_from_logits(logits: torch.Tensor) -> torch.Tensor:
        """
        Extract confidence scores from logits.

        Default implementation uses maximum softmax probability (MSP).
        Subclasses can override for different confidence measures.

        Args:
            logits: Tensor of shape (batch_size, num_classes)

        Returns:
            Confidence scores of shape (batch_size,) in range [0, 1]
        """
        return F.softmax(logits, dim=-1).max(dim=-1).values

    @staticmethod
    def reject(confidence: torch.Tensor, threshold: float) -> torch.Tensor:
        """
        Determine which samples should be rejected based on confidence.

        A sample is kept only when its confidence is greater than or equal
        to ``threshold``; every other sample is rejected. This includes
        samples whose confidence is NaN, which are always rejected.

        Args:
            confidence: Confidence scores of shape (batch_size,)
            threshold: Confidence threshold. Samples below this (or with
                NaN confidence) are rejected.

        Returns:
            Boolean tensor of shape (batch_size,) where True indicates
            the sample should be rejected (abstained).

        Example:
            >>> conf = torch.tensor([0.95, 0.60, 0.85, 0.40])
            >>> should_reject = BaseSelectivePredictor.reject(conf, threshold=0.7)
            >>> should_reject
            tensor([False,  True, False,  True])
        """
        # Written as "not accepted" instead of ``confidence < threshold``:
        # any comparison with NaN is False, so the direct form would let a
        # NaN confidence slip through as accepted. For all non-NaN values
        # the two forms are identical.
        return ~(confidence >= threshold)