"""
Training-time OOD detection methods.
Methods that improve OOD detection during training, rather than
post-hoc after training is complete.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Optional, Tuple
def mixup_data(
    x: torch.Tensor,
    y: torch.Tensor,
    alpha: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]:
    """
    Apply mixup data augmentation.

    Mixup creates virtual training examples by convex combination:

        x_tilde = λ*x_i + (1-λ)*x_j
        y_tilde = λ*y_i + (1-λ)*y_j

    where λ ~ Beta(α, α) and j ranges over a random permutation of the
    batch. This encourages linear behavior between examples and improves
    OOD detection by smoothing decision boundaries.

    Reference:
        Zhang et al. "mixup: Beyond Empirical Risk Minimization" (ICLR 2018)

    Args:
        x: Input batch (N, *). Any rank >= 1 is accepted; only the first
            (batch) dimension is permuted.
        y: Target labels (N,)
        alpha: Beta distribution parameter (default: 1.0). When
            ``alpha <= 0`` mixing is disabled and λ = 1.

    Returns:
        Tuple of (mixed_x, y_a, y_b, lambda) where:
        - mixed_x: Mixed inputs, same shape as ``x``
        - y_a, y_b: Original labels and the labels of the permuted partners
        - lambda: Mixing coefficient as a Python float

    Example:
        >>> mixed_x, y_a, y_b, lam = mixup_data(x, y, alpha=1.0)
        >>> outputs = model(mixed_x)
        >>> loss = lam * criterion(outputs, y_a) + (1-lam) * criterion(outputs, y_b)
    """
    # np.random.beta yields a NumPy scalar; cast so the value matches the
    # annotated ``float`` return type.
    lam = float(np.random.beta(alpha, alpha)) if alpha > 0 else 1.0

    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)

    # Index only the batch dimension: ``x[index, :]`` fails on 1-D inputs.
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam
def mixup_criterion(
    criterion: nn.Module,
    pred: torch.Tensor,
    y_a: torch.Tensor,
    y_b: torch.Tensor,
    lam: float,
) -> torch.Tensor:
    """
    Blend the losses of the two label sets produced by mixup.

    The criterion is evaluated once per label set and the results are
    combined with the same weight that mixed the inputs:

        Loss = λ * L(pred, y_a) + (1-λ) * L(pred, y_b)

    Args:
        criterion: Loss callable taking ``(pred, target)``
        pred: Model predictions on the mixed batch
        y_a: Labels of the original samples
        y_b: Labels of the permuted partner samples
        lam: Mixing coefficient returned by the augmentation

    Returns:
        The weighted sum of the two losses.

    Example:
        >>> mixed_x, y_a, y_b, lam = mixup_data(x, y)
        >>> outputs = model(mixed_x)
        >>> loss = mixup_criterion(nn.CrossEntropyLoss(), outputs, y_a, y_b, lam)
    """
    loss_first = criterion(pred, y_a)
    loss_second = criterion(pred, y_b)
    weight_second = 1 - lam
    return lam * loss_first + weight_second * loss_second
class OutlierExposureLoss(nn.Module):
    """
    Outlier Exposure (OE) for improved OOD detection.

    Trains the model to produce uniform predictions on an auxiliary
    outlier dataset, which improves OOD detection at test time.

        Loss = L_CE(f(x_in), y) + λ * L_OE(f(x_out))

    L_OE pulls the outlier predictive distribution p toward the uniform
    distribution u_i = 1/K over the K classes:

        L_OE = KL(u || p) = -(1/K) * sum_i log p_i - log K

    This is the cross-entropy to the uniform distribution, shifted by the
    constant log K. The gradients are the same, and the loss is zero when p
    is exactly uniform.

    Reference:
        Hendrycks et al. "Deep Anomaly Detection with Outlier Exposure"
        (ICLR 2019)

    Args:
        lambda_oe: Weight for OE loss (default: 0.5)

    Example:
        >>> criterion = OutlierExposureLoss(lambda_oe=0.5)
        >>> loss = criterion(logits_in, targets_in, logits_out)
    """

    def __init__(self, lambda_oe: float = 0.5):
        super().__init__()
        self.lambda_oe = lambda_oe

    def forward(
        self,
        logits_in: torch.Tensor,
        targets_in: torch.Tensor,
        logits_out: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute OE loss.

        Args:
            logits_in: In-distribution logits (N_in, C)
            targets_in: In-distribution targets (N_in,)
            logits_out: Outlier logits (N_out, C), optional. When omitted
                or empty, only the cross-entropy term is returned.

        Returns:
            Combined loss value
        """
        # Standard cross-entropy on in-distribution data
        loss_ce = F.cross_entropy(logits_in, targets_in)
        # An empty outlier batch would divide by zero in "batchmean" (NaN).
        if logits_out is None or logits_out.numel() == 0:
            return loss_ce

        num_classes = logits_out.size(-1)
        # log_softmax stays finite for confident logits, whereas
        # softmax(...).log() underflows to log(0) = -inf and yields an
        # infinite loss with NaN gradients.
        log_probs_out = F.log_softmax(logits_out, dim=-1)
        uniform = torch.full_like(log_probs_out, 1.0 / num_classes)

        # kl_div(input=log p, target=u) = sum u * (log u - log p) = KL(u || p)
        loss_oe = F.kl_div(log_probs_out, uniform, reduction="batchmean")
        return loss_ce + self.lambda_oe * loss_oe
class EnergyRegularizedLoss(nn.Module):
    """
    Energy-based margin regularization for OOD detection.

    Adds a hinge penalty that pushes the free energy of in-distribution
    samples below that of out-of-distribution samples by at least
    ``margin``:

        Energy(x) = -log sum_i exp(f_i(x))
        L_energy  = mean_{i,j} max(0, E(x_in_i) - E(x_out_j) + margin)

    The mean runs over every (ID, OOD) pair, so the two batches may have
    different sizes.

    NOTE: this is a pairwise-margin variant. The fine-tuning objective in
    Liu et al. instead uses squared hinges with separate in/out energy
    thresholds.

    Reference:
        Liu et al. "Energy-based Out-of-distribution Detection" (NeurIPS 2020)

    Args:
        lambda_energy: Energy regularization weight (default: 0.1)
        margin: Energy margin between ID and OOD (default: 10.0)

    Example:
        >>> criterion = EnergyRegularizedLoss()
        >>> loss = criterion(logits_in, targets_in, logits_out)
    """

    def __init__(self, lambda_energy: float = 0.1, margin: float = 10.0):
        super().__init__()
        self.lambda_energy = lambda_energy
        self.margin = margin

    def forward(
        self,
        logits_in: torch.Tensor,
        targets_in: torch.Tensor,
        logits_out: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute energy-regularized loss.

        Args:
            logits_in: In-distribution logits (N_in, C)
            targets_in: In-distribution targets (N_in,)
            logits_out: Out-of-distribution logits (N_out, C), optional.
                N_out need not equal N_in. When omitted or empty, only the
                cross-entropy term is returned.

        Returns:
            Combined loss value
        """
        # Standard cross-entropy
        loss_ce = F.cross_entropy(logits_in, targets_in)
        # Mean over an empty OOD batch would be NaN.
        if logits_out is None or logits_out.numel() == 0:
            return loss_ce

        # Energy scores: (N_in,) and (N_out,)
        energy_in = -torch.logsumexp(logits_in, dim=-1)
        energy_out = -torch.logsumexp(logits_out, dim=-1)

        # Hinge over all (ID, OOD) pairs -> (N_in, N_out). Subtracting the
        # 1-D vectors directly would crash when N_in != N_out, and would pair
        # samples arbitrarily when the sizes happen to match.
        gaps = energy_in[:, None] - energy_out[None, :] + self.margin
        loss_energy = F.relu(gaps).mean()

        return loss_ce + self.lambda_energy * loss_energy
class CutMix:
    """
    CutMix augmentation for improved robustness and OOD detection.

    Rather than blending whole images as mixup does, a rectangular patch
    from a partner image is pasted over each image:

        x_tilde = M ⊙ x_i + (1-M) ⊙ x_j
        y_tilde = λ*y_i + (1-λ)*y_j

    M is a binary box mask. λ is the fraction of each image that is kept.

    Reference:
        Yun et al. "CutMix: Regularization Strategy to Train Strong
        Classifiers with Localizable Features" (ICCV 2019)

    Args:
        alpha: Beta distribution parameter (default: 1.0)

    Example:
        >>> cutmix = CutMix(alpha=1.0)
        >>> mixed_x, y_a, y_b, lam = cutmix(x, y)
        >>> outputs = model(mixed_x)
        >>> loss = lam * criterion(outputs, y_a) + (1-lam) * criterion(outputs, y_b)
    """

    def __init__(self, alpha: float = 1.0):
        self.alpha = alpha

    def __call__(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]:
        """
        Paste a random box from a shuffled copy of the batch onto each image.

        Args:
            x: Input batch (N, C, H, W)
            y: Target labels (N,)

        Returns:
            Tuple of (mixed_x, y_a, y_b, lambda). Lambda is recomputed from
            the box area that survives clipping to the image bounds.
        """
        # Draw order on the NumPy RNG (beta, then x-centre, then y-centre)
        # must stay fixed so seeded runs reproduce.
        lam = np.random.beta(self.alpha, self.alpha) if self.alpha > 0 else 1.0
        partner = torch.randperm(x.size(0)).to(x.device)

        _, _, height, width = x.size()
        side_ratio = np.sqrt(1.0 - lam)
        half_h = int(height * side_ratio) // 2
        half_w = int(width * side_ratio) // 2

        centre_x = np.random.randint(width)
        centre_y = np.random.randint(height)
        left, right = np.clip([centre_x - half_w, centre_x + half_w], 0, width)
        top, bottom = np.clip([centre_y - half_h, centre_y + half_h], 0, height)

        mixed = x.clone()
        mixed[:, :, top:bottom, left:right] = x[partner, :, top:bottom, left:right]

        pasted_area = (right - left) * (bottom - top)
        lam = 1 - pasted_area / (width * height)
        return mixed, y, y[partner], lam
# Public API of this module: mixup/CutMix augmentation helpers and the
# OOD-aware training losses.
__all__ = [
    # mixup augmentation
    "mixup_data",
    "mixup_criterion",
    # training losses
    "OutlierExposureLoss",
    "EnergyRegularizedLoss",
    # patch-based augmentation
    "CutMix",
]