# Source code for incerto.data.loaders

"""
Data loading utilities for uncertainty quantification.

Provides standardized data loaders and sampling strategies.
"""

from __future__ import annotations
import torch
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from typing import Optional, Tuple
import numpy as np
from .utils import _get_targets


def create_dataloaders(
    train_dataset: Dataset,
    val_dataset: Optional[Dataset] = None,
    test_dataset: Optional[Dataset] = None,
    batch_size: int = 128,
    num_workers: int = 4,
    pin_memory: bool = True,
    shuffle_train: bool = True,
) -> Tuple[DataLoader, ...]:
    """
    Create standard data loaders for train/val/test.

    Args:
        train_dataset: Training dataset
        val_dataset: Validation dataset (optional)
        test_dataset: Test dataset (optional)
        batch_size: Batch size
        num_workers: Number of data loading workers
        pin_memory: Whether to pin memory
        shuffle_train: Whether to shuffle training data

    Returns:
        Tuple of DataLoaders (train, val, test).
        If val or test is None, returns None for that loader.
    """

    def _build(dataset: Dataset, shuffle: bool) -> DataLoader:
        # All three loaders share every setting except the dataset and
        # whether batches are shuffled (only training is).
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=pin_memory,
        )

    train_loader = _build(train_dataset, shuffle_train)
    val_loader = _build(val_dataset, False) if val_dataset is not None else None
    test_loader = _build(test_dataset, False) if test_dataset is not None else None
    return train_loader, val_loader, test_loader
def create_balanced_dataloader(
    dataset: Dataset,
    batch_size: int = 128,
    num_workers: int = 4,
    pin_memory: bool = True,
) -> DataLoader:
    """
    Create data loader with balanced class sampling.

    Ensures equal representation of all classes in expectation by drawing
    samples with inverse-class-frequency weights (with replacement).

    Args:
        dataset: Dataset with 'targets' or 'labels' attribute
        batch_size: Batch size
        num_workers: Number of workers
        pin_memory: Whether to pin memory

    Returns:
        DataLoader with balanced sampling
    """
    targets = _get_targets(dataset)

    # Inverse-frequency per-sample weights. `return_inverse` maps every
    # sample to its index in the unique-class array, so the weights come
    # from one vectorized gather instead of a Python dict lookup per sample.
    # Using np.unique also avoids division by zero with non-contiguous
    # class ids (same behavior as before, just faster on large datasets).
    _, inverse, class_counts = np.unique(
        targets, return_inverse=True, return_counts=True
    )
    sample_weights = (1.0 / class_counts)[inverse]

    # Sampler draws len(dataset) indices per epoch, with replacement, so
    # minority classes are oversampled to parity in expectation.
    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True,
    )

    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
def create_ood_dataloader(
    id_dataset: Dataset,
    ood_dataset: Dataset,
    batch_size: int = 128,
    num_workers: int = 4,
    pin_memory: bool = True,
    mixed: bool = False,
    mix_ratio: float = 0.5,
) -> Tuple[DataLoader, DataLoader] | DataLoader:
    """
    Create data loaders for OOD detection evaluation.

    Args:
        id_dataset: In-distribution dataset
        ood_dataset: Out-of-distribution dataset
        batch_size: Batch size
        num_workers: Number of workers
        pin_memory: Whether to pin memory
        mixed: If True, create single mixed loader; if False, separate loaders
        mix_ratio: Ratio of ID samples in mixed loader (only if mixed=True)

    Returns:
        If mixed=False: Tuple of (id_loader, ood_loader)
        If mixed=True: Single mixed loader
    """
    common = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    if not mixed:
        # Two unshuffled loaders so evaluation order is deterministic.
        id_loader = DataLoader(id_dataset, shuffle=False, **common)
        ood_loader = DataLoader(ood_dataset, shuffle=False, **common)
        return id_loader, ood_loader

    from torch.utils.data import ConcatDataset

    pooled = ConcatDataset([id_dataset, ood_dataset])

    n_id, n_ood = len(id_dataset), len(ood_dataset)
    # Per-sample weights chosen so that the expected fraction of ID
    # samples per draw equals mix_ratio, regardless of dataset sizes.
    per_id = mix_ratio / n_id
    per_ood = (1 - mix_ratio) / n_ood
    draw_weights = [per_id] * n_id + [per_ood] * n_ood

    sampler = WeightedRandomSampler(
        weights=draw_weights,
        num_samples=len(draw_weights),
        replacement=True,
    )

    return DataLoader(pooled, sampler=sampler, **common)
class InfiniteDataLoader:
    """
    Wrapper for infinite data loading.

    Useful for training with unequal dataset sizes or continuous sampling.
    """

    def __init__(self, dataloader: DataLoader):
        """
        Initialize infinite data loader.

        Args:
            dataloader: Base DataLoader to wrap
        """
        self.dataloader = dataloader
        self.iterator = iter(self.dataloader)

    def __iter__(self):
        # The wrapper is its own iterator; it simply never raises
        # StopIteration.
        return self

    def __next__(self):
        try:
            return next(self.iterator)
        except StopIteration:
            # One pass exhausted — transparently begin a fresh epoch.
            self.iterator = iter(self.dataloader)
            return next(self.iterator)

    def __len__(self):
        # Length of a single pass over the underlying loader.
        return len(self.dataloader)
def create_calibration_loaders(
    train_dataset: Dataset,
    test_dataset: Dataset,
    calib_split: float = 0.5,
    batch_size: int = 128,
    num_workers: int = 4,
    pin_memory: bool = True,
    seed: int = 42,
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Create loaders for calibration experiments.

    Splits training data into train/calibration sets, plus test set.

    Args:
        train_dataset: Training dataset
        test_dataset: Test dataset
        calib_split: Fraction of training data to use for calibration
        batch_size: Batch size
        num_workers: Number of workers
        pin_memory: Whether to pin memory
        seed: Random seed

    Returns:
        Tuple of (train_loader, calib_loader, test_loader)
    """
    from torch.utils.data import Subset

    total = len(train_dataset)
    n_calib = int(total * calib_split)
    split_at = total - n_calib

    # Seeded permutation so the train/calibration split is reproducible
    # across runs for the same seed.
    rng = torch.Generator().manual_seed(seed)
    perm = torch.randperm(total, generator=rng).tolist()

    fit_part = Subset(train_dataset, perm[:split_at])
    calib_part = Subset(train_dataset, perm[split_at:])

    common = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    # Only the fitting split is shuffled; calibration and test keep a
    # fixed order for deterministic evaluation.
    train_loader = DataLoader(fit_part, shuffle=True, **common)
    calib_loader = DataLoader(calib_part, shuffle=False, **common)
    test_loader = DataLoader(test_dataset, shuffle=False, **common)

    return train_loader, calib_loader, test_loader
def get_dataloader_stats(dataloader: DataLoader) -> dict:
    """
    Compute statistics about a DataLoader.

    Args:
        dataloader: DataLoader to analyze

    Returns:
        Dictionary with statistics
    """
    stats = {
        "num_batches": len(dataloader),
        "batch_size": dataloader.batch_size,
        "num_workers": dataloader.num_workers,
        "pin_memory": dataloader.pin_memory,
    }

    dataset = dataloader.dataset
    # Not every dataset is sized (e.g. iterable-style datasets).
    if hasattr(dataset, "__len__"):
        stats["dataset_size"] = len(dataset)

    # Class distribution is best-effort: datasets without a targets/labels
    # attribute make _get_targets raise ValueError, which we swallow.
    try:
        labels = _get_targets(dataset)
        classes, counts = np.unique(labels, return_counts=True)
        stats["num_classes"] = len(classes)
        stats["class_distribution"] = dict(zip(classes.tolist(), counts.tolist()))
    except ValueError:
        pass

    return stats