Source code for incerto.ood.base

"""
Base classes for out-of-distribution (OOD) detection methods.

All OOD detectors should inherit from OODDetector and implement
the score() method for their specific detection strategy.
"""

from abc import ABC, abstractmethod
import torch


class OODDetector(ABC):
    """
    Abstract base class for out-of-distribution (OOD) detection methods.

    OOD detectors identify when input data comes from a different distribution
    than the training data. This is critical for safety in deployed models, as
    predictions on OOD data are often unreliable.

    All detectors implement a scoring function where:

    - **Higher scores → more OOD-like**
    - **Lower scores → more in-distribution**

    On construction the model is put in eval mode and its parameters are
    frozen (``requires_grad=False``) for efficiency. Subclasses must implement
    score(), which defines the OOD scoring method.

    Example:
        >>> class MyOODDetector(OODDetector):
        ...     def score(self, x: torch.Tensor) -> torch.Tensor:
        ...         # Your OOD scoring logic
        ...         logits = self.model(x)
        ...         return -torch.max(logits, dim=1).values  # Higher = more OOD
        ...
        >>> detector = MyOODDetector(model)
        >>> ood_scores = detector.score(test_data)
        >>> is_ood = detector.predict(test_data, threshold=0.5)

    Attributes:
        model: The neural network, in eval mode with gradients disabled on
            its parameters.

    See Also:
        - Energy: Energy-based OOD detection (Liu et al., 2020)
        - ODIN: ODIN method with temperature and perturbations
        - Mahalanobis: Distance-based detection in feature space
        - MSP: Maximum softmax probability baseline
    """

    def __init__(self, model):
        """
        Initialize the OOD detector with a trained model.

        The model is automatically:
            1. Set to eval mode
            2. Frozen: every parameter gets ``requires_grad=False``

        Args:
            model: A trained PyTorch model (nn.Module). It is modified in
                place (eval mode, frozen parameters), not copied.

        Raises:
            TypeError: If model is not an nn.Module.
        """
        if not isinstance(model, torch.nn.Module):
            raise TypeError(
                f"model must be an nn.Module, got {type(model).__name__}"
            )
        # Module.eval() returns the module itself, so self.model is the
        # caller's object rather than a copy.
        self.model = model.eval()
        for p in self.model.parameters():
            p.requires_grad_(False)

    @abstractmethod
    def score(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute OOD scores for input samples.

        Higher scores indicate the input is more likely to be
        out-of-distribution.

        Args:
            x: Input tensor of shape (batch_size, *input_dims)

        Returns:
            OOD scores of shape (batch_size,) where higher values indicate
            more OOD-like samples.

        Note:
            The scale of scores depends on the detection method. Use the
            predict() method with a threshold for binary OOD decisions.
        """
        ...

    @torch.no_grad()
    def predict(self, x: torch.Tensor, threshold: float) -> torch.Tensor:
        """
        Predict whether inputs are OOD using a threshold.

        Args:
            x: Input tensor of shape (batch_size, *input_dims)
            threshold: Score threshold for OOD classification.
                Scores > threshold are classified as OOD.

        Returns:
            Boolean tensor of shape (batch_size,) where True indicates OOD.

        Note:
            score() runs inside ``torch.no_grad()`` here. A score()
            implementation that needs gradients with respect to the input
            must re-enable autograd locally, e.g. ``with torch.enable_grad():``.

        Example:
            >>> is_ood = detector.predict(test_data, threshold=0.5)
            >>> ood_count = is_ood.sum().item()
            >>> print(f"Detected {ood_count} OOD samples")
        """
        return self.score(x) > threshold  # Bool mask

    def state_dict(self) -> dict:
        """
        Return a dictionary containing the detector's state.

        Note:
            The model is NOT saved as part of the state dict. When loading,
            you must provide the model separately.

        Returns:
            Dictionary containing detector-specific parameters and fitted
            state. The base class has no state and returns an empty dict.
        """
        return {}

    def load_state_dict(self, state: dict) -> None:
        """
        Load detector state from a dictionary.

        Note:
            This does not load the model. The model must be set separately
            via the __init__ method.

        The base implementation has no state and ignores ``state``.
        Subclasses that override this should raise SerializationError when
        ``state`` is invalid.

        Args:
            state: Dictionary containing detector state.
        """
        pass

    def save(self, path: str) -> None:
        """
        Save detector state to a file (excluding the model).

        Args:
            path: File path where the state will be saved.

        Raises:
            SerializationError: If saving fails.

        Example:
            >>> detector.fit(train_loader)
            >>> detector.save('detector_state.pt')
        """
        from ..exceptions import SerializationError

        try:
            torch.save(self.state_dict(), path)
        except Exception as e:
            raise SerializationError(f"Failed to save to {path}: {e}") from e

    @classmethod
    def load(
        cls,
        path: str,
        model: torch.nn.Module,
        *,
        map_location=None,
        **kwargs,
    ) -> "OODDetector":
        """
        Load detector state from a file.

        Args:
            path: File path to load from.
            model: A trained PyTorch model to attach to the detector.
            map_location: Forwarded to ``torch.load``. For example, ``"cpu"``
                loads state saved from CUDA tensors on a machine without a
                GPU. ``None`` (default) restores tensors to the device they
                were saved from.
            **kwargs: Additional arguments for the detector constructor.

        Returns:
            An instance of ``cls`` with loaded state.

        Raises:
            SerializationError: If the file cannot be read, or if it does not
                contain a dict (the format produced by ``state_dict()``).
        """
        from ..exceptions import SerializationError

        try:
            # weights_only=True restricts unpickling to tensors, primitive
            # types and containers, so untrusted files cannot run code.
            state = torch.load(path, map_location=map_location, weights_only=True)
        except Exception as e:
            raise SerializationError(f"Failed to load from {path}: {e}") from e
        # Reject files that do not hold a state dict up front, instead of
        # letting load_state_dict fail obscurely (or silently accept it).
        if not isinstance(state, dict):
            raise SerializationError(
                f"Failed to load from {path}: expected a state dict, "
                f"got {type(state).__name__}"
            )
        detector = cls(model, **kwargs)
        detector.load_state_dict(state)
        return detector