# Source code for incerto.conformal.utils

"""
Utility functions for conformal prediction methods.
"""

from __future__ import annotations
import math
import torch


def compute_quantile(
    scores: torch.Tensor, alpha: float, adjusted: bool = True
) -> float:
    """
    Compute the conformal quantile at level (1 - alpha).

    Uses exact order-statistic indexing (ceil-quantile) rather than
    interpolation, which is required for the finite-sample coverage
    guarantee to hold.

    Args:
        scores: Conformity scores from calibration set. The tensor is
            flattened, so every element counts as one calibration score.
        alpha: Desired miscoverage rate.
        adjusted: If True, uses the finite-sample correction ⌈(1-α)(n+1)⌉;
            otherwise uses ⌈(1-α)n⌉.

    Returns:
        Quantile threshold: the k-th smallest score, as a Python float.

    Raises:
        ValueError: If ``scores`` contains no elements.

    Note:
        The rank k is clamped to [1, n]. In particular, when the corrected
        rank exceeds n (small calibration sets or very small alpha) the
        largest score is returned rather than +inf.
    """
    # Flatten so that n and the sort both refer to the same set of elements
    # (len() would only count the first dimension, sort only the last one).
    flat_scores = scores.reshape(-1)
    n = flat_scores.numel()
    if n == 0:
        raise ValueError(
            "compute_quantile requires at least one calibration score"
        )

    # 1-based rank of the order statistic to return.
    rank = math.ceil((1 - alpha) * ((n + 1) if adjusted else n))
    # Keep the rank inside the available order statistics.
    rank = max(1, min(rank, n))

    ascending = torch.sort(flat_scores)[0]
    return ascending[rank - 1].item()
def prediction_set_from_scores(
    scores: torch.Tensor, threshold: float, descending: bool = True
) -> list[torch.Tensor]:
    """
    Build one prediction set per sample by thresholding conformity scores.

    Args:
        scores: Score tensor of shape (n_samples, n_classes).
        threshold: Conformity threshold the scores are compared against.
        descending: If True, higher scores are better and a class is kept
            when its score is >= threshold; otherwise a class is kept when
            its score is <= threshold.

    Returns:
        A list with one entry per row of ``scores``; each entry is a tensor
        holding the indices of the classes kept for that row.
    """
    # Boolean membership mask: True where a class passes the threshold test.
    keep = (scores >= threshold) if descending else (scores <= threshold)

    # nonzero() yields a (k, 1) index column per row; drop the trailing axis.
    return [torch.nonzero(row).squeeze(-1) for row in keep]
def split_data(
    dataset: torch.utils.data.Dataset,
    cal_ratio: float = 0.5,
    seed: int | None = None,
) -> tuple[torch.utils.data.Subset, torch.utils.data.Subset]:
    """
    Split dataset into calibration and test sets.

    The indices 0..len(dataset)-1 are randomly permuted; the first
    ``int(len(dataset) * cal_ratio)`` of them form the calibration subset
    and the remainder form the test subset.

    Args:
        dataset: PyTorch dataset to split.
        cal_ratio: Fraction of data to use for calibration, in [0, 1].
        seed: Random seed for reproducibility. When given, a private
            generator is seeded with it, so the global PyTorch RNG state is
            left untouched. When None, the global RNG is used.

    Returns:
        Tuple of (calibration_dataset, test_dataset).

    Raises:
        ValueError: If ``cal_ratio`` is not within [0, 1].
    """
    # Out-of-range ratios would otherwise slice silently (a negative ratio
    # slices from the end), producing a split nobody asked for.
    if not 0.0 <= cal_ratio <= 1.0:
        raise ValueError(f"cal_ratio must be in [0, 1], got {cal_ratio}")

    n = len(dataset)
    if seed is None:
        permutation = torch.randperm(n)
    else:
        # A local generator keeps the call reproducible without reseeding
        # the process-wide RNG as a hidden side effect.
        generator = torch.Generator()
        generator.manual_seed(seed)
        permutation = torch.randperm(n, generator=generator, device="cpu")

    shuffled = permutation.tolist()
    n_cal = int(n * cal_ratio)

    cal_set = torch.utils.data.Subset(dataset, shuffled[:n_cal])
    test_set = torch.utils.data.Subset(dataset, shuffled[n_cal:])
    return cal_set, test_set