Source code for incerto.ood.utils
"""
Utility functions for OOD detection methods.
"""
import torch
import numpy as np
[docs]
def compute_threshold_at_tpr(
id_scores: torch.Tensor | np.ndarray,
target_tpr: float = 0.95,
) -> float:
"""
Compute OOD score threshold that accepts ``target_tpr`` fraction of ID samples.
The threshold is set at the ``target_tpr``-th percentile of the ID score
distribution, so that ``target_tpr`` of in-distribution samples fall below
it (i.e. are correctly classified as ID).
Args:
id_scores: Scores from in-distribution data (lower = more ID-like).
target_tpr: Fraction of ID samples to accept (default: 0.95).
Returns:
Threshold value.
"""
if isinstance(id_scores, torch.Tensor):
id_scores = id_scores.cpu().numpy()
threshold = np.percentile(id_scores, target_tpr * 100)
return float(threshold)
[docs]
def get_ood_predictions(
scores: torch.Tensor | np.ndarray, threshold: float
) -> np.ndarray:
"""
Get binary OOD predictions based on threshold.
Args:
scores: OOD scores (higher = more OOD-like).
threshold: Decision threshold.
Returns:
Binary array (1 = OOD, 0 = ID).
"""
if isinstance(scores, torch.Tensor):
scores = scores.cpu().numpy()
return (scores > threshold).astype(int)
[docs]
def extract_features(
    model: torch.nn.Module,
    data_loader: torch.utils.data.DataLoader,
    layer_name: str = "penultimate",
) -> torch.Tensor:
    """
    Extract features from a specific layer of the model.

    Uses ``str.endswith`` matching on module names, consistent with the
    hook mechanism in :class:`Mahalanobis` and :class:`KNN`. The first module
    (in ``named_modules()`` order) whose name ends with ``layer_name`` is
    hooked, and its output is collected for every batch.

    Note:
        The model is switched to evaluation mode via ``model.eval()`` and is
        left in that mode when the function returns.

    Args:
        model: PyTorch model.
        data_loader: Iterable of batches. A batch may be a tensor, or a
            list/tuple whose first element is the input tensor.
        layer_name: Name (suffix) of layer to extract features from.

    Returns:
        Tensor of extracted features on the CPU, flattened from dimension 1
        onwards and concatenated along dimension 0. An empty tensor is
        returned when no features were captured.

    Raises:
        ValueError: If no module name ends with ``layer_name``.
    """
    model.eval()

    # First module whose qualified name ends with the requested suffix.
    target_module = next(
        (
            module
            for name, module in model.named_modules()
            if name.endswith(layer_name)
        ),
        None,
    )
    if target_module is None:
        raise ValueError(f"Layer '{layer_name}' not found in model")

    # Resolve the model's device once. Models without parameters (and without
    # buffers) have no device of their own, so inputs are left where they are
    # instead of failing with StopIteration.
    reference = next(model.parameters(), None)
    if reference is None:
        reference = next(model.buffers(), None)
    device = reference.device if reference is not None else None

    captured = {}

    def hook(module, inputs, output):
        # Keep only the most recent output of the hooked layer.
        captured["output"] = output.detach()

    features = []
    handle = target_module.register_forward_hook(hook)
    try:
        with torch.no_grad():
            for batch in data_loader:
                x = batch[0] if isinstance(batch, (list, tuple)) else batch
                if device is not None:
                    x = x.to(device)
                model(x)
                # pop() so a forward pass that skips the hooked layer cannot
                # re-append the previous batch's activation.
                output = captured.pop("output", None)
                if output is not None:
                    features.append(output.flatten(1).cpu())
    finally:
        # Always unregister the hook, even if a forward pass raised.
        handle.remove()

    return torch.cat(features, dim=0) if features else torch.tensor([])