# incerto/ood/methods.py
import torch
import torch.nn.functional as F
from .base import OODDetector
class MSP(OODDetector):
    """Maximum-Softmax-Probability (Hendrycks & Gimpel, 2017).

    Scores an input by how *unconfident* the classifier is: one minus the
    largest softmax probability. Higher score = more OOD-like.
    """

    def score(self, x):
        """Return ``1 - max softmax probability`` per sample (shape ``(N,)``)."""
        probs = F.softmax(self.model(x), dim=-1)
        top_prob = probs.max(dim=-1).values
        return 1 - top_prob

    def __repr__(self) -> str:
        return "MSP()"
class Energy(OODDetector):
    """Energy-based score (Liu et al., NeurIPS 2020).

    Uses the negative log-sum-exp of temperature-scaled logits; in-distribution
    inputs tend to have lower (more negative) energy than OOD inputs.
    """

    def __init__(self, model, temperature=1.0):
        super().__init__(model)
        # Temperature divides the logits before the log-sum-exp.
        self.T = temperature

    def score(self, x):
        """Return the energy ``-logsumexp(logits / T)`` per sample."""
        scaled = self.model(x) / self.T
        return -torch.logsumexp(scaled, dim=-1)

    def state_dict(self) -> dict:
        """Save temperature parameter."""
        return {"temperature": self.T}

    def load_state_dict(self, state: dict) -> None:
        """Load temperature parameter."""
        self.T = state["temperature"]

    def __repr__(self) -> str:
        return f"Energy(temperature={self.T})"
class ODIN(OODDetector):
    """ODIN (Liang et al., ICLR 2018).

    Combines temperature scaling with a small input perturbation that
    pushes samples toward *higher* softmax confidence; the perturbation
    helps ID samples much more than OOD samples, widening the score gap.
    """

    def __init__(self, model, temperature=1000.0, epsilon=0.0014):
        super().__init__(model)
        # T: softmax temperature; eps: perturbation magnitude (paper defaults).
        self.T, self.eps = temperature, epsilon

    def score(self, x):
        """Return the ODIN score per sample (higher = more OOD-like).

        Runs one backward pass to obtain the input gradient, perturbs the
        input, then re-scores the perturbed input with temperature scaling.
        """
        x = x.clone().requires_grad_(True)
        logits = self.model(x) / self.T
        smax = F.softmax(logits, dim=-1)
        # Loss decreases as max-softmax confidence increases.
        loss = -smax.max(dim=-1).values.mean()
        loss.backward()
        # Fix: step *against* the loss gradient so the perturbation
        # increases confidence (Liang et al., Eq. 2). The previous
        # `x + eps * grad.sign()` moved in the wrong direction and
        # reduced confidence instead.
        x_adv = (x - self.eps * x.grad.sign()).detach()
        with torch.no_grad():  # scoring pass needs no autograd graph
            logits_adv = self.model(x_adv) / self.T
        return -F.softmax(logits_adv, dim=-1).max(dim=-1).values

    def state_dict(self) -> dict:
        """Save ODIN hyperparameters."""
        return {"temperature": self.T, "epsilon": self.eps}

    def load_state_dict(self, state: dict) -> None:
        """Load ODIN hyperparameters."""
        self.T = state["temperature"]
        self.eps = state["epsilon"]

    def __repr__(self) -> str:
        return f"ODIN(temperature={self.T}, epsilon={self.eps})"
class Mahalanobis(OODDetector):
    """Feature-space Mahalanobis (Lee et al., NeurIPS 2018).

    Fits per-class feature means and a shared (tied) covariance on training
    data, then scores inputs by the minimum squared Mahalanobis distance to
    any class mean. Higher score = more OOD-like.
    """

    def __init__(self, model, layer_name="penultimate"):
        super().__init__(model)
        # Recorded so state_dict()/__repr__ report the actual layer
        # (previously never set, so state_dict always fell back to the default).
        self._layer_name = layer_name
        self.layer = self._hook(layer_name)
        self.class_means, self.precision = None, None  # filled by `fit`

    def fit(self, loader):
        """Estimate class means and the shared precision matrix from `loader`."""
        device = next(self.model.parameters()).device
        acts, labels = [], []
        # Inference only: no_grad avoids retaining autograd graphs on the
        # accumulated feature tensors.
        with torch.no_grad():
            for x, y in loader:
                self.model(x.to(device))
                acts.append(self.layer().flatten(1).cpu())
                labels.append(y.cpu())
        acts = torch.cat(acts)
        labels = torch.cat(labels)
        self.class_means = torch.stack(
            [acts[labels == c].mean(0) for c in torch.unique(labels)]
        )
        cov = torch.cov(acts.T)
        # Small ridge term keeps the inverse numerically stable.
        self.precision = torch.linalg.inv(cov + 1e-6 * torch.eye(cov.size(0)))

    def score(self, x):
        """Return min-over-classes squared Mahalanobis distance, shape ``(N,)``."""
        if self.class_means is None:
            raise RuntimeError("Must call .fit() before .score()")
        with torch.no_grad():
            self.model(x)
            # Fitted statistics live on CPU (see `fit`); move features there
            # so scoring also works when the model runs on GPU.
            f = self.layer().flatten(1).cpu()
        diff = f[:, None] - self.class_means  # N×C×D
        d2 = (diff @ self.precision * diff).sum(-1)  # N×C
        return d2.min(dim=-1).values.to(x.device)

    def _hook(self, name):
        # Register a forward hook on the first module whose name ends with
        # `name`; the hook stashes that module's output for later retrieval.
        for n, m in self.model.named_modules():
            if n.endswith(name):
                m.register_forward_hook(
                    lambda _, __, out: setattr(self, "_tmp", out)
                )
                return lambda: self._tmp
        raise ValueError(f"Layer {name} not found")

    def state_dict(self) -> dict:
        """Save fitted Mahalanobis parameters."""
        return {
            "class_means": self.class_means,
            "precision": self.precision,
            "layer_name": getattr(self, "_layer_name", "penultimate"),
        }

    def load_state_dict(self, state: dict) -> None:
        """Load fitted Mahalanobis parameters."""
        from ..exceptions import SerializationError
        try:
            self.class_means = state["class_means"]
            self.precision = state["precision"]
            self._layer_name = state.get("layer_name", "penultimate")
        except Exception as e:
            raise SerializationError(f"Failed to load state: {e}") from e

    def __repr__(self) -> str:
        fitted = self.class_means is not None
        n_classes = len(self.class_means) if fitted else "not fitted"
        return f"Mahalanobis(layer='{self._layer_name}', n_classes={n_classes})"
class MaxLogit(OODDetector):
    """
    MaxLogit OOD detection (Hendrycks et al., 2019).
    Uses the maximum logit value as the OOD score. Simpler than MSP
    and often more effective as it doesn't require softmax normalization.
    """

    def score(self, x):
        """Return the negated maximum logit per sample (higher = more OOD)."""
        top_logit = self.model(x).max(dim=-1).values
        # Large logits signal in-distribution, so negate for an OOD score.
        return -top_logit

    def __repr__(self) -> str:
        return "MaxLogit()"
class KNN(OODDetector):
    """
    KNN-based OOD detection (Sun et al., NeurIPS 2022).
    Computes OOD score as distance to k-th nearest neighbor in feature space.
    Requires fitting on training data.
    """

    def __init__(self, model, k=50, layer_name="penultimate"):
        super().__init__(model)
        self.k = k
        # Recorded so state_dict() reports the actual layer
        # (previously never set, so it always fell back to the default).
        self._layer_name = layer_name
        self.layer = self._hook(layer_name)
        self.train_features = None

    def fit(self, loader):
        """Store training features for KNN computation."""
        device = next(self.model.parameters()).device
        features = []
        # Inference only: no_grad prevents the stored feature bank from
        # retaining autograd graphs (unbounded memory growth otherwise).
        with torch.no_grad():
            for x, _ in loader:
                self.model(x.to(device))
                features.append(self.layer().flatten(1).cpu())
        self.train_features = torch.cat(features)

    def score(self, x):
        """Compute OOD score as distance to k-th nearest neighbor."""
        if self.train_features is None:
            raise RuntimeError("Must call .fit() before .score()")
        with torch.no_grad():
            self.model(x)
            test_features = self.layer().flatten(1).cpu()
        # Pairwise Euclidean distances: (N_test, N_train).
        dists = torch.cdist(test_features, self.train_features)
        # kthvalue raises if k exceeds the number of stored train features.
        kth_dist, _ = torch.kthvalue(dists, self.k, dim=-1)
        return kth_dist.to(x.device)

    def _hook(self, name):
        # Register a forward hook on the first module whose name ends with
        # `name`; the hook stashes that module's output for later retrieval.
        for n, m in self.model.named_modules():
            if n.endswith(name):
                m.register_forward_hook(
                    lambda _, __, out: setattr(self, "_tmp", out)
                )
                return lambda: self._tmp
        raise ValueError(f"Layer {name} not found")

    def state_dict(self) -> dict:
        """Save KNN training features."""
        return {
            "k": self.k,
            "layer_name": getattr(self, "_layer_name", "penultimate"),
            "train_features": self.train_features,
        }

    def load_state_dict(self, state: dict) -> None:
        """Load KNN training features."""
        from ..exceptions import SerializationError
        try:
            self.k = state["k"]
            self._layer_name = state.get("layer_name", "penultimate")
            self.train_features = state["train_features"]
        except Exception as e:
            raise SerializationError(f"Failed to load state: {e}") from e

    def __repr__(self) -> str:
        fitted = self.train_features is not None
        n_samples = len(self.train_features) if fitted else "not fitted"
        return f"KNN(k={self.k}, n_train_samples={n_samples})"