incerto.ood.OODDetector

class incerto.ood.OODDetector(model)

Bases: ABC

Abstract base class for out-of-distribution (OOD) detection methods.

OOD detectors identify when input data comes from a different distribution than the training data. This is critical for safety in deployed models, as predictions on OOD data are often unreliable.

All detectors implement a scoring function where:

  • Higher scores → more OOD-like

  • Lower scores → more in-distribution
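As a minimal illustration of this convention, the sketch below turns a confidence measure (the maximum softmax probability, as in the MSP baseline) into an OOD score by negating it, so that flatter, less confident predictions score higher. It uses plain Python lists in place of tensors, and the helper names softmax and msp_ood_score are hypothetical, not part of incerto.

```python
import math


def softmax(logits):
    # Subtract the max logit for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def msp_ood_score(logits):
    # Negate the maximum softmax probability so that
    # higher score -> more OOD-like, lower score -> more in-distribution.
    return -max(softmax(logits))


confident = [8.0, 0.5, 0.1]  # peaked logits: looks in-distribution
uncertain = [1.0, 0.9, 1.1]  # flat logits: looks OOD-like

print(msp_ood_score(confident) < msp_ood_score(uncertain))  # True
```

Negating a confidence value rather than, say, taking its reciprocal keeps the score monotone in the original measure, so any ranking of samples is preserved.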

The model is automatically set to eval mode and gradients are disabled for efficiency.

Subclasses must implement score(), which defines the OOD scoring method.

Example

>>> import torch
>>> from incerto.ood import OODDetector
>>> class MyOODDetector(OODDetector):
...     def score(self, x: torch.Tensor) -> torch.Tensor:
...         # Example scoring rule: negated maximum logit
...         logits = self.model(x)
...         return -torch.max(logits, dim=1).values  # Higher = more OOD
...
>>> detector = MyOODDetector(model)
>>> ood_scores = detector.score(test_data)
>>> is_ood = detector.predict(test_data, threshold=0.5)
model

The neural network in eval mode with gradients disabled.

See also

  • Energy: Energy-based OOD detection (Liu et al., 2020)

  • ODIN: Softmax temperature scaling combined with small input perturbations

  • Mahalanobis: Distance-based detection in feature space

  • MSP: Maximum softmax probability baseline
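For orientation, the sketch below computes the energy score behind the Energy detector listed above, following the definition of Liu et al. (2020): E(x) = -T · log Σᵢ exp(fᵢ(x) / T), where fᵢ(x) are the logits and T is a temperature. It is a plain-Python illustration on a list of logits, not incerto's Energy implementation; the function name energy_score and the default T = 1 are assumptions.

```python
import math


def energy_score(logits, temperature=1.0):
    # Energy E(x) = -T * log(sum_i exp(f_i(x) / T)), computed with the
    # log-sum-exp trick for numerical stability. In-distribution inputs tend
    # to have lower energy, so the energy itself already follows the
    # "higher score = more OOD" convention without negation.
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    lse = m + math.log(sum(math.exp(s - m) for s in scaled))
    return -temperature * lse


print(energy_score([8.0, 0.5, 0.1]) < energy_score([1.0, 0.9, 1.1]))  # True
```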

__init__(model)

Initialize the OOD detector with a trained model.

The model is automatically:

  1. Set to eval mode.

  2. Frozen, with gradients disabled (requires_grad=False).

Parameters:

model – A trained PyTorch model (nn.Module)

Raises:

TypeError – If model is not an nn.Module.
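The two initialization steps and the TypeError check can be sketched as follows. This is a hypothetical stand-in class (FrozenModelDetector), not incerto's source; it assumes PyTorch is installed and only mirrors the behavior described above.

```python
import torch.nn as nn


class FrozenModelDetector:
    """Hypothetical stand-in for OODDetector.__init__ (not incerto's code)."""

    def __init__(self, model):
        # Reject anything that is not a PyTorch module.
        if not isinstance(model, nn.Module):
            raise TypeError(f"model must be an nn.Module, got {type(model).__name__}")
        # Step 1: eval mode (dropout off, batch norm uses running statistics).
        model.eval()
        # Step 2: freeze parameters so no gradients are tracked for them.
        for p in model.parameters():
            p.requires_grad_(False)
        self.model = model


net = nn.Sequential(nn.Linear(4, 8), nn.Dropout(0.5), nn.Linear(8, 3))
detector = FrozenModelDetector(net)
print(detector.model.training)  # False
```

Note that freezing modifies the module in place, so the caller's model object is also left in eval mode with frozen parameters.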

Methods

  • __init__(model) – Initialize the OOD detector with a trained model.

  • load(path, model, **kwargs) – Load detector state from a file.

  • load_state_dict(state) – Load detector state from a dictionary.

  • predict(x, threshold) – Predict whether inputs are OOD using a threshold.

  • save(path) – Save detector state to a file (excluding the model).

  • score(x) – Compute OOD scores for input samples.

  • state_dict() – Return a dictionary containing the detector’s state.

abstract score(x)

Compute OOD scores for input samples.

Higher scores indicate the input is more likely to be out-of-distribution.

Parameters:

x (Tensor) – Input tensor of shape (batch_size, *input_dims)

Return type:

Tensor

Returns:

OOD scores of shape (batch_size,) where higher values indicate more OOD-like samples.

Note

The scale of scores depends on the detection method, so a threshold tuned for one detector does not carry over to another. Use the predict() method with a threshold for binary OOD decisions.
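Because score scales differ across detectors, one common recipe (an assumption here, not something this class does for you) is to calibrate the threshold on held-out in-distribution data, for example choosing the value that 95% of in-distribution validation scores fall at or below. The hypothetical helper threshold_at_quantile below sketches that recipe in plain Python; its result could then be passed to predict().

```python
import math


def threshold_at_quantile(id_scores, q=0.95):
    # Hypothetical helper: nearest-rank q-quantile of in-distribution scores.
    # Roughly a fraction (1 - q) of in-distribution samples will then score
    # strictly above the threshold and be flagged as OOD (false positives).
    if not id_scores:
        raise ValueError("need at least one in-distribution score")
    ranked = sorted(id_scores)
    k = min(len(ranked) - 1, max(0, math.ceil(q * len(ranked)) - 1))
    return ranked[k]


# Example: 20 validation scores computed on in-distribution data.
val_scores = [float(i) for i in range(20)]
thr = threshold_at_quantile(val_scores)
flagged = sum(s > thr for s in val_scores)
print(thr, flagged)  # 18.0 1
```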

predict(x, threshold)

Predict whether inputs are OOD using a threshold.

Parameters:
  • x (Tensor) – Input tensor of shape (batch_size, *input_dims)

  • threshold (float) – Score threshold for OOD classification. Scores > threshold are classified as OOD.

Return type:

Tensor

Returns:

Boolean tensor of shape (batch_size,) where True indicates OOD.

Example

>>> is_ood = detector.predict(test_data, threshold=0.5)
>>> ood_count = is_ood.sum().item()
>>> print(f"Detected {ood_count} OOD samples")
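Given the documented rule (scores > threshold are OOD), predict() presumably reduces to an element-wise comparison on the output of score(). The sketch below shows that comparison on a precomputed score tensor; the helper name predict_from_scores is hypothetical and the example assumes PyTorch is installed.

```python
import torch


def predict_from_scores(scores: torch.Tensor, threshold: float) -> torch.Tensor:
    # Strict comparison, as documented: scores > threshold are OOD,
    # so a score exactly equal to the threshold counts as in-distribution.
    return scores > threshold


scores = torch.tensor([0.2, 0.5, 0.9])
is_ood = predict_from_scores(scores, threshold=0.5)
print(is_ood.tolist())  # [False, False, True]
```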
state_dict()

Return a dictionary containing the detector’s state.

Note

The model is NOT saved as part of the state dict. When loading, you must provide the model separately.

Return type:

dict

Returns:

Dictionary containing detector-specific parameters and fitted state.

load_state_dict(state)

Load detector state from a dictionary.

Note

This does not load the model. The model must be supplied separately when the detector is constructed (via __init__).

Parameters:

state (dict) – Dictionary containing detector state.

Raises:

SerializationError – If state is invalid.

Return type:

None
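The state_dict() / load_state_dict() contract (fitted state only, model supplied separately, invalid state rejected) can be illustrated with a hypothetical fitted detector that stores per-class means, loosely in the spirit of Mahalanobis-style detectors. Everything here is an assumption for illustration: MeanDistanceDetector, the class_means key, the string used as a model placeholder, and the local SerializationError stand-in are not incerto's code.

```python
class SerializationError(Exception):
    """Local stand-in for the library's SerializationError."""


class MeanDistanceDetector:
    """Hypothetical fitted detector; only its fitted state is serialized."""

    def __init__(self, model):
        self.model = model       # never included in the state dict
        self.class_means = None  # fitted state

    def state_dict(self):
        return {"class_means": self.class_means}

    def load_state_dict(self, state):
        # Validate before mutating so a bad state leaves the detector unchanged.
        if not isinstance(state, dict) or "class_means" not in state:
            raise SerializationError("state is missing 'class_means'")
        self.class_means = state["class_means"]


fitted = MeanDistanceDetector(model="trained-model")  # placeholder model
fitted.class_means = [[0.0, 1.0], [2.0, 3.0]]
state = fitted.state_dict()

restored = MeanDistanceDetector(model="trained-model")  # model supplied separately
restored.load_state_dict(state)
print(restored.class_means == fitted.class_means)  # True
```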

save(path)

Save detector state to a file (excluding the model).

Parameters:

path (str) – File path where the state will be saved.

Raises:

SerializationError – If saving fails.

Return type:

None

Example

>>> detector.fit(train_loader)  # only for detectors that define fit(); not part of this base class
>>> detector.save('detector_state.pt')
classmethod load(path, model, **kwargs)

Load detector state from a file.

Parameters:
  • path (str) – File path to load from.

  • model (Module) – A trained PyTorch model to attach to the detector.

  • **kwargs – Additional keyword arguments passed to the detector constructor.

Return type:

OODDetector

Returns:

An OODDetector instance with loaded state.

Raises:

SerializationError – If loading fails.
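One plausible shape for the save()/load() pair is sketched below: save() writes state_dict() with torch.save, and the load() classmethod builds a fresh instance around the caller's model (forwarding **kwargs to the constructor) before restoring the saved state. TinyDetector and its temperature field are hypothetical, the sketch assumes PyTorch is installed, and it omits the SerializationError wrapping the real methods document.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyDetector:
    """Hypothetical detector showing the save/load pattern (not incerto's code)."""

    def __init__(self, model, temperature=1.0):
        self.model = model
        self.temperature = temperature

    def state_dict(self):
        # Only detector state; the model is deliberately left out.
        return {"temperature": self.temperature}

    def load_state_dict(self, state):
        self.temperature = state["temperature"]

    def save(self, path):
        torch.save(self.state_dict(), path)

    @classmethod
    def load(cls, path, model, **kwargs):
        # Build a fresh instance around the caller's model, then restore state.
        detector = cls(model, **kwargs)
        detector.load_state_dict(torch.load(path))
        return detector


model = nn.Linear(4, 3)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "detector_state.pt")
    TinyDetector(model, temperature=2.5).save(path)
    restored = TinyDetector.load(path, model)
print(restored.temperature)  # 2.5
```

Making load() a classmethod lets each subclass return an instance of its own type without overriding the method.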