"""
Standard OOD detection benchmarks.
Provides utilities for creating common OOD detection benchmark scenarios
used in uncertainty quantification research.
"""
from __future__ import annotations
import torch
from torch.utils.data import Dataset, Subset
try:
from torchvision import datasets, transforms
except ImportError:
raise ImportError(
"torchvision is required for incerto.data.ood_benchmarks. "
"Install it with: pip install incerto[vision]"
)
from typing import Tuple, List
from pathlib import Path
import numpy as np
class OODBenchmark:
    """
    Base class for OOD detection benchmarks.

    Provides in-distribution (ID) and out-of-distribution (OOD) datasets
    for evaluation.
    """

    def __init__(self, root: str | Path = "./data"):
        """Store the data root and make sure the directory exists.

        Args:
            root: Directory where datasets are stored / downloaded.
        """
        self.root = Path(root)
        # Create the directory eagerly so subclasses can download into it
        # without checking first.
        self.root.mkdir(parents=True, exist_ok=True)

    def get_datasets(self) -> Tuple[Dataset, Dataset]:
        """
        Get ID and OOD datasets.

        Returns:
            Tuple of (id_dataset, ood_dataset)

        Raises:
            NotImplementedError: subclasses must override this method.
        """
        raise NotImplementedError
class MNIST_vs_FashionMNIST(OODBenchmark):
    """
    MNIST (ID) vs Fashion-MNIST (OOD) benchmark.

    A common near-OOD benchmark where both datasets have similar
    structure but different semantics.
    """

    def __init__(
        self,
        root: str | Path = "./data",
        normalize: bool = True,
    ):
        """
        Args:
            root: Data root directory.
            normalize: Whether to normalize with MNIST mean/std.
        """
        super().__init__(root)
        self.normalize = normalize

    def get_transforms(self) -> transforms.Compose:
        """Build the transform pipeline shared by the ID and OOD sets.

        Fix: this method was called by ``get_datasets`` but never defined,
        causing an AttributeError at runtime.
        """
        steps = [transforms.ToTensor()]
        if self.normalize:
            # Standard MNIST mean/std (same constants as MNIST_vs_NotMNIST).
            steps.append(transforms.Normalize((0.1307,), (0.3081,)))
        return transforms.Compose(steps)

    def get_datasets(self) -> Tuple[Dataset, Dataset]:
        """Get MNIST (ID) and Fashion-MNIST (OOD) test sets."""
        transform = self.get_transforms()
        id_dataset = datasets.MNIST(
            root=self.root,
            train=False,
            download=True,
            transform=transform,
        )
        ood_dataset = datasets.FashionMNIST(
            root=self.root,
            train=False,
            download=True,
            transform=transform,
        )
        return id_dataset, ood_dataset
class CIFAR10_vs_CIFAR100(OODBenchmark):
    """
    CIFAR-10 (ID) vs CIFAR-100 (OOD) benchmark.

    Near-OOD benchmark with similar image statistics but different classes.
    """

    def __init__(
        self,
        root: str | Path = "./data",
        normalize: bool = True,
    ):
        """
        Args:
            root: Data root directory.
            normalize: Whether to normalize with CIFAR-10 mean/std.
        """
        super().__init__(root)
        self.normalize = normalize

    def get_transforms(self) -> transforms.Compose:
        """Build the transform pipeline shared by the ID and OOD sets.

        Fix: this method was called by ``get_datasets`` but never defined,
        causing an AttributeError at runtime.
        """
        steps = [transforms.ToTensor()]
        if self.normalize:
            # Widely used CIFAR-10 per-channel statistics — confirm they
            # match the statistics used when training the evaluated model.
            steps.append(
                transforms.Normalize(
                    (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
                )
            )
        return transforms.Compose(steps)

    def get_datasets(self) -> Tuple[Dataset, Dataset]:
        """Get CIFAR-10 (ID) and CIFAR-100 (OOD) test sets."""
        transform = self.get_transforms()
        id_dataset = datasets.CIFAR10(
            root=self.root,
            train=False,
            download=True,
            transform=transform,
        )
        ood_dataset = datasets.CIFAR100(
            root=self.root,
            train=False,
            download=True,
            transform=transform,
        )
        return id_dataset, ood_dataset
class CIFAR10_vs_SVHN(OODBenchmark):
    """
    CIFAR-10 (ID) vs SVHN (OOD) benchmark.

    Far-OOD benchmark with different image statistics.
    """

    def __init__(
        self,
        root: str | Path = "./data",
        normalize: bool = True,
    ):
        """
        Args:
            root: Data root directory.
            normalize: Whether to normalize with CIFAR-10 mean/std.
        """
        super().__init__(root)
        self.normalize = normalize

    def get_transforms(self) -> transforms.Compose:
        """Build the transform pipeline shared by the ID and OOD sets.

        Fix: this method was called by ``get_datasets`` but never defined,
        causing an AttributeError at runtime.
        """
        steps = [transforms.ToTensor()]
        if self.normalize:
            # Widely used CIFAR-10 per-channel statistics — confirm they
            # match the statistics used when training the evaluated model.
            steps.append(
                transforms.Normalize(
                    (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
                )
            )
        return transforms.Compose(steps)

    def get_datasets(self) -> Tuple[Dataset, Dataset]:
        """Get CIFAR-10 (ID) and SVHN (OOD) test sets."""
        transform = self.get_transforms()
        id_dataset = datasets.CIFAR10(
            root=self.root,
            train=False,
            download=True,
            transform=transform,
        )
        # SVHN uses a `split` argument instead of `train`.
        ood_dataset = datasets.SVHN(
            root=self.root,
            split="test",
            download=True,
            transform=transform,
        )
        return id_dataset, ood_dataset
class MNIST_vs_NotMNIST(OODBenchmark):
    """
    MNIST (ID) vs NotMNIST (OOD) benchmark.

    NotMNIST contains letters instead of digits.

    Note: NotMNIST requires manual download.
    """

    def __init__(
        self,
        root: str | Path = "./data",
        normalize: bool = True,
    ):
        """
        Args:
            root: Data root directory.
            normalize: Whether to normalize with MNIST mean/std.
        """
        super().__init__(root)
        self.normalize = normalize

    def get_transforms(self) -> transforms.Compose:
        """Build the MNIST (ID) transform pipeline.

        Fix: this method was called by ``get_datasets`` but never defined,
        causing an AttributeError at runtime.
        """
        steps = [transforms.ToTensor()]
        if self.normalize:
            # Standard MNIST mean/std; kept identical to the OOD pipeline
            # built in get_datasets so both sets share normalization.
            steps.append(transforms.Normalize((0.1307,), (0.3081,)))
        return transforms.Compose(steps)

    def get_datasets(self) -> Tuple[Dataset, Dataset]:
        """Get MNIST (ID) and NotMNIST (OOD) test sets.

        Raises:
            FileNotFoundError: if NotMNIST has not been manually downloaded
                to ``<root>/notMNIST``.
        """
        id_transform = self.get_transforms()
        id_dataset = datasets.MNIST(
            root=self.root,
            train=False,
            download=True,
            transform=id_transform,
        )
        # NotMNIST loaded via ImageFolder produces RGB; convert to grayscale
        # and resize to match MNIST dimensions.
        ood_transform_list = [
            transforms.Grayscale(num_output_channels=1),
            transforms.Resize(28),
            transforms.ToTensor(),
        ]
        if self.normalize:
            ood_transform_list.append(transforms.Normalize((0.1307,), (0.3081,)))
        ood_transform = transforms.Compose(ood_transform_list)

        # Guard clause: fail fast with a clear message if the manual
        # download is missing.
        notmnist_path = self.root / "notMNIST"
        if not notmnist_path.exists():
            raise FileNotFoundError(
                f"NotMNIST not found at {notmnist_path}. "
                "Please download NotMNIST dataset manually."
            )
        ood_dataset = datasets.ImageFolder(
            root=notmnist_path,
            transform=ood_transform,
        )
        return id_dataset, ood_dataset
class SubclassOOD(OODBenchmark):
    """
    Create OOD benchmark by holding out specific classes.

    Example: Train on classes 0-7 of CIFAR-10, test with classes 8-9 as OOD.
    """

    def __init__(
        self,
        dataset_class,
        root: str | Path = "./data",
        id_classes: List[int] | None = None,
        ood_classes: List[int] | None = None,
        transform: transforms.Compose | None = None,
        **dataset_kwargs,
    ):
        """
        Initialize subclass OOD benchmark.

        Args:
            dataset_class: Dataset class (e.g., datasets.CIFAR10)
            root: Data root directory
            id_classes: List of class indices to use as ID
            ood_classes: List of class indices to use as OOD
            transform: Transform to apply (defaults to ToTensor())
            **dataset_kwargs: Additional arguments for dataset
        """
        super().__init__(root)
        self.dataset_class = dataset_class
        self.id_classes = id_classes
        self.ood_classes = ood_classes
        self.transform = transform if transform is not None else transforms.ToTensor()
        self.dataset_kwargs = dataset_kwargs

    def get_datasets(self) -> Tuple[Dataset, Dataset]:
        """Get ID and OOD subsets.

        If only one of ``id_classes`` / ``ood_classes`` was specified, the
        other side is automatically set to the complement of the dataset's
        observed classes.

        Raises:
            ValueError: if neither id_classes nor ood_classes was given, or
                the dataset exposes neither ``targets`` nor ``labels``.
        """
        if self.id_classes is None and self.ood_classes is None:
            raise ValueError(
                "At least one of id_classes or ood_classes must be specified"
            )
        # Load full dataset; default to the test split unless the caller
        # already chose one via `train` or `split`.
        kwargs = dict(root=self.root, download=True, transform=self.transform)
        if "train" not in self.dataset_kwargs and "split" not in self.dataset_kwargs:
            kwargs["train"] = False
        kwargs.update(self.dataset_kwargs)
        full_dataset = self.dataset_class(**kwargs)

        # torchvision datasets expose labels as either .targets or .labels
        # (e.g. SVHN uses .labels).
        if hasattr(full_dataset, "targets"):
            targets = np.asarray(full_dataset.targets)
        elif hasattr(full_dataset, "labels"):
            targets = np.asarray(full_dataset.labels)
        else:
            raise ValueError("Dataset must have 'targets' or 'labels' attribute")

        # Auto-complement: if only one side specified, the other gets the rest.
        all_classes = set(np.unique(targets).tolist())
        id_classes = self.id_classes
        ood_classes = self.ood_classes
        if id_classes is not None and ood_classes is None:
            ood_classes = sorted(all_classes - set(id_classes))
        elif ood_classes is not None and id_classes is None:
            id_classes = sorted(all_classes - set(ood_classes))

        # Vectorized membership test instead of a Python-level loop over
        # every sample (same indices, O(n) in C).
        id_indices = np.flatnonzero(np.isin(targets, list(id_classes))).tolist()
        ood_indices = np.flatnonzero(np.isin(targets, list(ood_classes))).tolist()

        return Subset(full_dataset, id_indices), Subset(full_dataset, ood_indices)
class CorruptedDataOOD(OODBenchmark):
    """
    Create OOD benchmark using corrupted versions of ID data.

    Applies various corruptions (noise, blur, etc.) to create OOD samples.
    """

    def __init__(
        self,
        dataset: Dataset,
        corruption_type: str = "gaussian_noise",
        severity: float = 0.5,
    ):
        """
        Initialize corrupted data OOD benchmark.

        Args:
            dataset: Base dataset
            corruption_type: Type of corruption to apply
            severity: Corruption severity (0-1)
        """
        # No data directory is needed here, so OODBenchmark.__init__
        # (which would create one) is intentionally skipped.
        self.root = None
        self.dataset = dataset
        self.corruption_type = corruption_type
        self.severity = severity

    def apply_corruption(self, image: torch.Tensor) -> torch.Tensor:
        """Return a corrupted copy of ``image`` per the configured type."""
        kind = self.corruption_type
        if kind == "gaussian_noise":
            # Additive zero-mean noise scaled by severity.
            return image + torch.randn_like(image) * self.severity
        if kind == "salt_pepper":
            # Randomly force pixels to 0 ("pepper") or 1 ("salt").
            # NOTE(review): assumes pixel values live in [0, 1] — confirm.
            rnd = torch.rand_like(image)
            out = image.clone()
            out[rnd < self.severity / 2] = 0.0
            out[rnd > 1 - self.severity / 2] = 1.0
            return out
        if kind == "blur":
            # Simple box blur via average pooling; kernel grows with
            # severity and is forced odd so padding preserves spatial size.
            k = int(self.severity * 10) + 1
            if k % 2 == 0:
                k += 1
            # (C, H, W) -> (1, C, H, W) for avg_pool2d, then back.
            return torch.nn.functional.avg_pool2d(
                image.unsqueeze(0), k, stride=1, padding=k // 2
            ).squeeze(0)
        raise ValueError(f"Unknown corruption type: {self.corruption_type}")

    def get_datasets(self) -> Tuple[Dataset, Dataset]:
        """Get ID (clean) and OOD (corrupted) datasets."""

        class CorruptedDataset(Dataset):
            """Lazy wrapper that corrupts each sample on access."""

            def __init__(self, base_dataset, corruption_fn):
                self.base_dataset = base_dataset
                self.corruption_fn = corruption_fn

            def __len__(self):
                return len(self.base_dataset)

            def __getitem__(self, idx):
                image, label = self.base_dataset[idx]
                return self.corruption_fn(image), label

        # ID is the original dataset; OOD is its lazily-corrupted view.
        return self.dataset, CorruptedDataset(self.dataset, self.apply_corruption)
def get_ood_benchmark(
    name: str,
    root: str | Path = "./data",
    **kwargs,
) -> OODBenchmark:
    """
    Factory function to get OOD benchmark by name.

    Args:
        name: Benchmark name
        root: Data root directory
        **kwargs: Additional arguments

    Returns:
        OODBenchmark instance

    Available benchmarks:
        - "mnist_vs_fmnist": MNIST vs Fashion-MNIST
        - "cifar10_vs_cifar100": CIFAR-10 vs CIFAR-100
        - "cifar10_vs_svhn": CIFAR-10 vs SVHN
        - "mnist_vs_notmnist": MNIST vs NotMNIST
    """
    registry = {
        "mnist_vs_fmnist": MNIST_vs_FashionMNIST,
        "cifar10_vs_cifar100": CIFAR10_vs_CIFAR100,
        "cifar10_vs_svhn": CIFAR10_vs_SVHN,
        "mnist_vs_notmnist": MNIST_vs_NotMNIST,
    }
    if name not in registry:
        raise ValueError(
            f"Unknown benchmark: {name}. Available: {list(registry.keys())}"
        )
    benchmark_cls = registry[name]
    return benchmark_cls(root=root, **kwargs)