Data Utilities

The data module provides vision dataset classes, out-of-distribution (OOD) detection benchmarks, data loader helpers, and dataset utilities.

Vision Datasets

VisionDataset([root, val_split, seed])

Base class for vision datasets with standardized splits.

MNIST([root, val_split, seed, normalize])

MNIST dataset with standardized splits.

FashionMNIST([root, val_split, seed, normalize])

Fashion-MNIST dataset with standardized splits.

CIFAR10([root, val_split, seed, normalize, ...])

CIFAR-10 dataset with standardized splits.

CIFAR100([root, val_split, seed, normalize, ...])

CIFAR-100 dataset with standardized splits.

SVHN([root, val_split, seed, normalize])

SVHN (Street View House Numbers) dataset.

OOD Benchmarks

OODBenchmark([root])

Base class for OOD detection benchmarks.

get_ood_benchmark(name[, root])

Factory function that returns an OOD benchmark by name.
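A name-based factory is commonly built on a small registry that maps strings to classes. The sketch below shows that pattern in plain Python. It is not the library's implementation: `register`, `get_benchmark`, `ToyMNISTvsFashion`, and the case-insensitive lookup are all hypothetical names and choices made for illustration.

```python
from typing import Callable, Dict

# Hypothetical registry mapping lowercase names to benchmark classes.
_REGISTRY: Dict[str, Callable[..., object]] = {}


def register(name):
    """Class decorator that stores the class under a lowercase key."""
    def decorator(cls):
        _REGISTRY[name.lower()] = cls
        return cls
    return decorator


@register("mnist_vs_fashionmnist")
class ToyMNISTvsFashion:
    """Stand-in benchmark class; it only remembers its root directory."""
    def __init__(self, root="./data"):
        self.root = root


def get_benchmark(name, root="./data"):
    """Look up a registered benchmark by name and instantiate it."""
    key = name.lower()
    if key not in _REGISTRY:
        raise ValueError(
            f"Unknown benchmark {name!r}; available: {sorted(_REGISTRY)}"
        )
    return _REGISTRY[key](root=root)


benchmark = get_benchmark("MNIST_vs_FashionMNIST", root="/tmp/data")
```

Raising an error that lists the registered names makes a misspelled benchmark name easy to diagnose.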

MNIST_vs_FashionMNIST([root, normalize])

MNIST (in-distribution, ID) vs Fashion-MNIST (OOD) benchmark.

CIFAR10_vs_CIFAR100([root, normalize])

CIFAR-10 (ID) vs CIFAR-100 (OOD) benchmark.

CIFAR10_vs_SVHN([root, normalize])

CIFAR-10 (ID) vs SVHN (OOD) benchmark.

MNIST_vs_NotMNIST([root, normalize])

MNIST (ID) vs NotMNIST (OOD) benchmark.

SubclassOOD(dataset_class[, root, ...])

Create an OOD benchmark by holding out specific classes.
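The idea of a held-out-class benchmark can be sketched without any library: samples whose label is in the held-out set become the OOD data, and everything else stays in-distribution. The function name `split_by_held_out_classes` and the `(input, label)` tuple format are assumptions for this sketch, not the signature of `SubclassOOD`.

```python
def split_by_held_out_classes(samples, held_out):
    """Partition (input, label) pairs into ID and OOD lists by label."""
    held_out = set(held_out)
    id_part = [(x, y) for x, y in samples if y not in held_out]
    ood_part = [(x, y) for x, y in samples if y in held_out]
    return id_part, ood_part


# Toy dataset: inputs are strings, labels are integers.
toy = [("a", 0), ("b", 1), ("c", 2), ("d", 1), ("e", 3)]
id_part, ood_part = split_by_held_out_classes(toy, held_out=[1, 3])
```

A model trained only on `id_part` never sees classes 1 and 3, so `ood_part` can serve as semantically novel test data.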

CorruptedDataOOD(dataset[, corruption_type, ...])

Create an OOD benchmark from corrupted versions of the ID data.

Data Loaders

create_dataloaders(train_dataset[, ...])

Create standard data loaders for train/val/test.

create_balanced_dataloader(dataset[, ...])

Create a data loader with balanced class sampling.
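One common way to balance class sampling is to weight each sample by the inverse of its class frequency, so that every class carries the same total sampling mass. The sketch below computes such weights in plain Python. Whether `create_balanced_dataloader` works this way is an assumption, and `inverse_frequency_weights` is a hypothetical helper.

```python
from collections import Counter


def inverse_frequency_weights(labels):
    """Return one weight per sample, equal to 1 / (count of its class)."""
    counts = Counter(labels)
    return [1.0 / counts[y] for y in labels]


# Class 0 has three samples, class 1 has one.
labels = [0, 0, 0, 1]
weights = inverse_frequency_weights(labels)
```

In PyTorch, per-sample weights of this kind are typically passed to `torch.utils.data.WeightedRandomSampler`, which draws samples in proportion to their weights.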

create_ood_dataloader(id_dataset, ood_dataset)

Create data loaders for OOD detection evaluation.

create_calibration_loaders(train_dataset, ...)

Create loaders for calibration experiments.

InfiniteDataLoader(dataloader)

Wrapper that iterates over a data loader indefinitely.
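An infinite wrapper is usually a generator that restarts the underlying iterable each time it is exhausted. The following is a minimal sketch of that behavior over any re-iterable object. `InfiniteLoader` is a hypothetical stand-in, not the library's `InfiniteDataLoader`, and the empty-loader check is a design choice of this sketch.

```python
from itertools import islice


class InfiniteLoader:
    """Yield batches forever by restarting the wrapped iterable."""

    def __init__(self, loader):
        self.loader = loader

    def __iter__(self):
        while True:
            yielded = False
            for batch in self.loader:
                yielded = True
                yield batch
            if not yielded:
                # Avoid spinning forever on an empty loader.
                raise ValueError("underlying loader is empty")


# A list stands in for a data loader; take seven "batches" from it.
batches = list(islice(InfiniteLoader([1, 2, 3]), 7))
```

This pattern suits training loops that count steps rather than epochs, since the loop never has to handle `StopIteration`.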

get_dataloader_stats(dataloader)

Compute statistics about a DataLoader.

Dataset Utilities

split_dataset(dataset, splits[, seed])

Split dataset into multiple subsets.
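A seeded split can be illustrated as shuffling indices once with a fixed seed and slicing them into consecutive parts. The sketch below assumes `splits` is a list of fractions that sum to 1; the library may instead accept absolute sizes, and `split_indices` is a hypothetical helper rather than `split_dataset` itself.

```python
import random


def split_indices(n, fractions, seed=0):
    """Shuffle range(n) with a fixed seed and cut it into parts."""
    if abs(sum(fractions) - 1.0) > 1e-8:
        raise ValueError("fractions must sum to 1")
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    parts, start = [], 0
    for i, frac in enumerate(fractions):
        # The last part takes the remainder so no index is dropped.
        is_last = i == len(fractions) - 1
        end = n if is_last else start + int(round(frac * n))
        parts.append(indices[start:end])
        start = end
    return parts


train_idx, val_idx, test_idx = split_indices(10, [0.7, 0.2, 0.1], seed=42)
```

Using a private `random.Random(seed)` instance keeps the split reproducible without touching the global random state.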

filter_dataset_by_class(dataset, classes[, ...])

Filter dataset to only include (or exclude) specific classes.

get_class_balanced_subset(dataset, ...[, seed])

Create a class-balanced subset with equal samples per class.
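Class-balanced subsetting amounts to grouping indices by label and drawing the same number from each group. The sketch below shows this on a plain list of labels. The helper name `balanced_subset_indices`, its `samples_per_class` parameter, and the choice to raise when a class is too small are assumptions, not the documented behavior of `get_class_balanced_subset`.

```python
import random
from collections import defaultdict


def balanced_subset_indices(labels, samples_per_class, seed=0):
    """Pick the same number of indices from every class, without replacement."""
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    rng = random.Random(seed)
    chosen = []
    for y in sorted(by_class):
        pool = by_class[y]
        if len(pool) < samples_per_class:
            raise ValueError(f"class {y} has only {len(pool)} samples")
        chosen.extend(rng.sample(pool, samples_per_class))
    return sorted(chosen)


labels = [0, 0, 0, 0, 1, 1, 2, 2, 2]
subset = balanced_subset_indices(labels, samples_per_class=2, seed=1)
```

The returned indices could then be wrapped in an index-based subset view, such as `torch.utils.data.Subset` in PyTorch.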

compute_dataset_statistics(dataset)

Compute statistics about a dataset.

create_imbalanced_dataset(dataset[, ...])

Create an imbalanced version of a dataset.

TransformDataset(dataset, transform)

Wrapper to apply transformations to a dataset.

LabelNoiseDataset(dataset[, noise_rate, ...])

Add label noise to a dataset.
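Label noise comes in several variants. The sketch below shows symmetric noise, where each label is replaced, with probability `noise_rate`, by a different class chosen uniformly at random. Which noise model `LabelNoiseDataset` uses is not stated here, so the function `add_symmetric_label_noise` and its `num_classes` parameter are illustrative assumptions.

```python
import random


def add_symmetric_label_noise(labels, num_classes, noise_rate, seed=0):
    """Flip each label to a different uniform-random class with given probability."""
    rng = random.Random(seed)
    noisy = []
    for y in labels:
        if rng.random() < noise_rate:
            # Exclude the true class so a flip always changes the label.
            candidates = [c for c in range(num_classes) if c != y]
            noisy.append(rng.choice(candidates))
        else:
            noisy.append(y)
    return noisy


clean = [0, 1, 2, 3, 0, 1, 2, 3]
noisy = add_symmetric_label_noise(clean, num_classes=4, noise_rate=0.5, seed=7)
```

Because a flip always changes the label in this sketch, the expected fraction of corrupted labels equals `noise_rate`.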

merge_datasets(*datasets)

Merge multiple datasets into one.

subsample_dataset(dataset[, fraction, seed])

Randomly subsample a fraction of the dataset.