# Source code for incerto.ood.methods

# incerto/ood/methods.py
import torch
import torch.nn.functional as F

from .base import OODDetector


class _FeatureHookMixin:
    """Mixin providing a forward-hook mechanism for feature extraction."""

    _hook_handle: object

    def _hook(self, name):
        for n, m in self.model.named_modules():
            if n.endswith(name):
                if self._hook_handle is not None:
                    self._hook_handle.remove()
                self._hook_handle = m.register_forward_hook(
                    lambda _, __, out: setattr(self, "_tmp", out)
                )
                return lambda: self._tmp
        raise ValueError(f"Layer {name} not found")

    def __del__(self):
        if getattr(self, "_hook_handle", None) is not None:
            self._hook_handle.remove()


class MSP(OODDetector):
    """Maximum-Softmax-Probability (Hendrycks & Gimpel, 2017)."""

    def score(self, x):
        """Return ``1 - max softmax probability`` (higher = more OOD-like)."""
        probs = F.softmax(self.model(x), dim=-1)
        top_prob = probs.max(dim=-1).values
        return 1 - top_prob

    def __repr__(self) -> str:
        return "MSP()"
class Energy(OODDetector):
    """Energy-based score (Liu et al., NeurIPS 2020)."""

    def __init__(self, model, temperature=1.0):
        super().__init__(model)
        self.temperature = temperature

    def score(self, x):
        """Return the negative log-sum-exp of scaled logits (higher = OOD)."""
        logits = self.model(x)
        return -torch.logsumexp(logits / self.temperature, dim=-1)

    def state_dict(self) -> dict:
        """Save temperature parameter."""
        return {"temperature": self.temperature}

    def load_state_dict(self, state: dict) -> None:
        """Load temperature parameter."""
        self.temperature = state["temperature"]

    def __repr__(self) -> str:
        return f"Energy(temperature={self.temperature})"
class ODIN(OODDetector):
    """
    ODIN: Out-of-Distribution detector for Neural Networks
    (Liang et al., ICLR 2018).

    Uses temperature scaling and input perturbation to separate ID/OOD
    scores.

    Note: ``score()`` computes input gradients for perturbation and must
    NOT be called inside a ``torch.no_grad()`` context. Use ``score()``
    directly instead of ``predict()`` — the inherited ``predict()`` wraps
    the call in ``@torch.no_grad()`` which silently disables the
    perturbation mechanism and degrades detection performance.
    """

    def __init__(self, model, temperature=1000.0, epsilon=0.0014):
        super().__init__(model)
        self.temperature = temperature
        self.epsilon = epsilon

    def score(self, x):
        """Score inputs via temperature-scaled softmax on perturbed inputs.

        Raises:
            RuntimeError: if called with gradients globally disabled.
        """
        if not torch.is_grad_enabled():
            raise RuntimeError(
                "ODIN.score() requires gradients for input perturbation but "
                "is running inside a torch.no_grad() context. Call score() "
                "directly instead of predict()."
            )
        x = x.clone().requires_grad_(True)
        logits = self.model(x) / self.temperature
        smax = F.softmax(logits, dim=-1)
        loss = -smax.max(dim=-1).values.mean()
        # Fix: use autograd.grad() on the input instead of loss.backward(),
        # which also accumulated gradients into the model's parameter .grad
        # buffers as a side effect of scoring.
        (grad,) = torch.autograd.grad(loss, x)
        # Detach the perturbed input — no graph is needed for the re-forward.
        x_adv = (x - self.epsilon * grad.sign()).detach()
        with torch.no_grad():
            logits_adv = self.model(x_adv) / self.temperature
        return -F.softmax(logits_adv, dim=-1).max(dim=-1).values

    def predict(self, x: torch.Tensor, threshold: float) -> torch.Tensor:
        """Predict whether inputs are OOD using a threshold.

        Overrides the base class to avoid ``@torch.no_grad()``, which would
        break the input-perturbation mechanism.
        """
        return self.score(x) > threshold

    def state_dict(self) -> dict:
        """Save ODIN hyperparameters."""
        return {"temperature": self.temperature, "epsilon": self.epsilon}

    def load_state_dict(self, state: dict) -> None:
        """Load ODIN hyperparameters."""
        self.temperature = state["temperature"]
        self.epsilon = state["epsilon"]

    def __repr__(self) -> str:
        return f"ODIN(temperature={self.temperature}, epsilon={self.epsilon})"
class Mahalanobis(_FeatureHookMixin, OODDetector):
    """Feature-space Mahalanobis (Lee et al., NeurIPS 2018)."""

    def __init__(self, model, layer_name="penultimate"):
        super().__init__(model)
        self._layer_name = layer_name
        self._hook_handle = None
        self.layer = self._hook(layer_name)
        self.class_means, self.precision = None, None  # filled by `fit`

    def fit(self, loader):
        """Estimate per-class feature means and a shared (ridged) precision
        matrix from a labelled data loader."""
        device = next(self.model.parameters()).device  # hoisted out of loop
        acts, labels = [], []
        # Fix: run feature extraction under no_grad() — previously the hooked
        # features retained autograd graphs for the entire dataset.
        with torch.no_grad():
            for x, y in loader:
                self.model(x.to(device))
                acts.append(self.layer().flatten(1).cpu())
                labels.append(y.cpu())
        acts = torch.cat(acts)
        labels = torch.cat(labels)
        self.class_means = torch.stack(
            [acts[labels == c].mean(0) for c in torch.unique(labels)]
        )
        cov = torch.cov(acts.T)
        # Ridge regularisation keeps the covariance invertible.
        ridge = max(1e-6, 1e-6 * cov.diag().mean().item())
        self.precision = torch.linalg.inv(cov + ridge * torch.eye(cov.size(0)))

    @torch.no_grad()
    def score(self, x):
        """Return the minimum squared Mahalanobis distance to any class mean.

        Raises:
            NotFittedError: if :meth:`fit` has not been called.
        """
        if self.class_means is None:
            from ..exceptions import NotFittedError

            raise NotFittedError("Must call .fit() before .score()")
        self.model(x)
        f = self.layer().flatten(1)
        # Move fitted stats to feature device for GPU compatibility
        cm = self.class_means.to(f.device)
        prec = self.precision.to(f.device)
        diff = f[:, None] - cm  # N×C×D
        d2 = (diff @ prec * diff).sum(-1)  # N×C
        return d2.min(dim=-1).values

    def state_dict(self) -> dict:
        """Save fitted Mahalanobis parameters."""
        return {
            "class_means": self.class_means,
            "precision": self.precision,
            "layer_name": getattr(self, "_layer_name", "penultimate"),
        }

    def load_state_dict(self, state: dict) -> None:
        """Load fitted Mahalanobis parameters."""
        from ..exceptions import SerializationError

        try:
            self.class_means = state["class_means"]
            self.precision = state["precision"]
            new_layer = state.get("layer_name", "penultimate")
            if new_layer != self._layer_name:
                # Re-hook if the saved state targets a different layer.
                self._layer_name = new_layer
                self.layer = self._hook(new_layer)
        except SerializationError:
            raise
        except Exception as e:
            raise SerializationError(f"Failed to load state: {e}") from e

    def __repr__(self) -> str:
        fitted = self.class_means is not None
        n_classes = len(self.class_means) if fitted else "not fitted"
        return f"Mahalanobis(layer='{self._layer_name}', n_classes={n_classes})"
class MaxLogit(OODDetector):
    """
    MaxLogit OOD detection (Hendrycks et al., 2019).

    Uses the maximum logit value as the OOD score. Simpler than MSP and
    often more effective as it doesn't require softmax normalization.
    """

    def score(self, x):
        """Return the negated max logit (higher logit = more ID-like)."""
        top_logit = self.model(x).max(dim=-1).values
        return -top_logit

    def __repr__(self) -> str:
        return "MaxLogit()"
class KNN(_FeatureHookMixin, OODDetector):
    """
    KNN-based OOD detection (Sun et al., NeurIPS 2022).

    Computes OOD score as distance to k-th nearest neighbor in feature
    space. Requires fitting on training data.
    """

    def __init__(self, model, k=50, layer_name="penultimate"):
        super().__init__(model)
        self.k = k
        self._layer_name = layer_name
        self._hook_handle = None
        self.layer = self._hook(layer_name)
        self.train_features = None

    def fit(self, loader):
        """Store training features for KNN computation."""
        device = next(self.model.parameters()).device  # hoisted out of loop
        features = []
        # Fix: extract features under no_grad() — previously every stored
        # feature tensor retained its autograd graph.
        with torch.no_grad():
            for x, _ in loader:
                self.model(x.to(device))
                features.append(self.layer().flatten(1).cpu())
        self.train_features = torch.cat(features)

    @torch.no_grad()
    def score(self, x):
        """Compute OOD score as distance to k-th nearest neighbor.

        Raises:
            NotFittedError: if :meth:`fit` has not been called.
        """
        if self.train_features is None:
            from ..exceptions import NotFittedError

            raise NotFittedError("Must call .fit() before .score()")
        self.model(x)
        test_features = self.layer().flatten(1)
        # Compute pairwise distances (both on same device)
        train_f = self.train_features.to(test_features.device)
        dists = torch.cdist(test_features, train_f)
        # Get k-th nearest neighbor distance
        # Note: kthvalue not supported on MPS, fall back to CPU
        if dists.device.type == "mps":
            kth_dist, _ = torch.kthvalue(dists.cpu(), self.k, dim=-1)
            return kth_dist.to(test_features.device)
        kth_dist, _ = torch.kthvalue(dists, self.k, dim=-1)
        return kth_dist

    def state_dict(self) -> dict:
        """Save KNN training features."""
        return {
            "k": self.k,
            "layer_name": getattr(self, "_layer_name", "penultimate"),
            "train_features": self.train_features,
        }

    def load_state_dict(self, state: dict) -> None:
        """Load KNN training features."""
        from ..exceptions import SerializationError

        try:
            self.k = state["k"]
            self.train_features = state["train_features"]
            new_layer = state.get("layer_name", "penultimate")
            if new_layer != self._layer_name:
                # Re-hook if the saved state targets a different layer.
                self._layer_name = new_layer
                self.layer = self._hook(new_layer)
        except SerializationError:
            raise
        except Exception as e:
            raise SerializationError(f"Failed to load state: {e}") from e

    def __repr__(self) -> str:
        fitted = self.train_features is not None
        n_samples = len(self.train_features) if fitted else "not fitted"
        return f"KNN(k={self.k}, n_train_samples={n_samples})"