Calibration

The calibration module provides methods for calibrating neural network predictions to ensure that confidence scores accurately reflect the true probability of correctness.

Base Class

BaseCalibrator()

Abstract base class for all calibration methods.

Post-hoc Calibration Methods

TemperatureScaling([init_temp])

Temperature scaling for calibration: scales logits by a learned temperature.

VectorScaling(n_classes)

Vector Scaling (Guo et al., 2017).

MatrixScaling(n_classes)

Matrix Scaling (Guo et al., 2017).

PlattScalingCalibrator()

Platt scaling (logistic regression) calibration per class (one-vs-rest).

IsotonicRegressionCalibrator([out_of_bounds])

Multi-class isotonic regression calibration (per-class fitting).

HistogramBinningCalibrator([n_bins])

Histogram binning calibration: bins predicted probabilities and uses empirical frequencies.
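Histogram binning replaces each predicted confidence with the empirical accuracy of its bin, estimated on held-out data. A minimal NumPy sketch of the idea (illustrative function names, not this module's API; empty bins fall back to the bin midpoint as an assumption):

```python
import numpy as np

def fit_histogram_binning(conf, correct, n_bins=10):
    # per equal-width bin, record the empirical accuracy
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    idx = np.clip(np.digitize(conf, edges) - 1, 0, n_bins - 1)
    acc = np.array([correct[idx == b].mean() if (idx == b).any()
                    else 0.5 * (edges[b] + edges[b + 1])  # empty bin: keep midpoint
                    for b in range(n_bins)])
    return edges, acc

def apply_histogram_binning(conf, edges, acc):
    # map each new confidence to its bin's empirical accuracy
    idx = np.clip(np.digitize(conf, edges) - 1, 0, len(acc) - 1)
    return acc[idx]
```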

DirichletCalibrator(n_classes[, mu])

Dirichlet Calibration (Kull et al., 2019).

BetaCalibrator()

Beta Calibration for binary classification (Kull et al., 2017).

IdentityCalibrator()

No-op calibrator that returns the original softmax probabilities.

Training-time Calibration

LabelSmoothingLoss([smoothing, reduction])

Label Smoothing for improved calibration.
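Label smoothing mixes the one-hot target with the uniform distribution, (1 - eps) * one_hot + eps / K, which discourages the network from driving logit gaps to infinity. A minimal NumPy sketch of the target construction (the helper name is illustrative, not this module's API):

```python
import numpy as np

def smooth_targets(labels, n_classes, smoothing=0.1):
    # (1 - eps) * one-hot + eps * uniform over n_classes
    t = np.full((len(labels), n_classes), smoothing / n_classes)
    t[np.arange(len(labels)), labels] += 1.0 - smoothing
    return t
```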

FocalLoss([alpha, gamma, reduction])

Focal Loss for handling hard examples.
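Focal loss down-weights easy, already-confident examples by the modulating factor (1 - p_t)^gamma, recovering plain cross-entropy at gamma = 0. A minimal NumPy sketch of the formula (illustrative name and signature, not this module's API):

```python
import numpy as np

def focal_loss(probs, labels, alpha=1.0, gamma=2.0):
    # -alpha * (1 - p_t)^gamma * log(p_t), averaged over samples
    p_t = probs[np.arange(len(labels)), labels]
    return float(np.mean(-alpha * (1.0 - p_t) ** gamma * np.log(p_t)))
```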

ConfidencePenalty([beta])

Confidence Penalty to prevent overconfidence.

evidential_loss(evidence, targets, ...[, ...])

Evidential Deep Learning loss.

get_uncertainty_from_evidence(evidence, ...)

Compute uncertainty measures from evidential outputs.
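In the standard evidential formulation (Sensoy et al., 2018), non-negative per-class evidence parameterizes a Dirichlet with alpha = evidence + 1; the expected probabilities are alpha / S and the vacuity-style uncertainty is K / S, where S is the Dirichlet strength. A sketch of those formulas, assuming this module follows that convention (the function name is illustrative):

```python
import numpy as np

def uncertainty_from_evidence(evidence):
    # Dirichlet parameters: alpha = evidence + 1
    alpha = evidence + 1.0
    strength = alpha.sum(axis=1, keepdims=True)   # S = sum_k alpha_k
    probs = alpha / strength                      # expected class probabilities
    k = evidence.shape[1]
    u = k / strength[:, 0]                        # uncertainty K / S, in (0, 1]
    return probs, u
```

Zero evidence yields the uniform Dirichlet and maximal uncertainty u = 1; evidence concentrated on one class drives u toward 0.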

TemperatureAwareTraining(backbone[, ...])

Temperature-aware training with learnable temperature.

Metrics

ece_score(logits, labels[, n_bins])

Expected Calibration Error (ECE).

mce_score(logits, labels[, n_bins])

Maximum Calibration Error (MCE).

classwise_ece(logits, labels[, n_bins])

Class-wise ECE: average ECE computed separately for each class.

adaptive_ece_score(logits, labels[, n_bins, ...])

Adaptive Expected Calibration Error (Nixon et al., 2019).

smooth_ece(logits, labels)

Smooth Expected Calibration Error (smECE).

brier_score(logits, labels)

Brier score: mean squared error between one-hot labels and predicted probabilities.
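The multi-class Brier score here averages, over samples, the squared distance between the softmax vector and the one-hot label (note that some texts additionally divide by the number of classes; this sketch assumes the sum-over-classes convention stated above):

```python
import numpy as np

def brier(logits, labels):
    # mean over samples of ||softmax(logits) - one_hot(labels)||^2
    z = logits - logits.max(axis=1, keepdims=True)
    p = np.exp(z)
    p /= p.sum(axis=1, keepdims=True)
    onehot = np.eye(p.shape[1])[labels]
    return float(((p - onehot) ** 2).sum(axis=1).mean())
```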

nll(logits, labels)

Negative Log-Likelihood (cross-entropy) averaged over samples.

Visualization

plot_reliability_diagram(logits, labels[, ...])

Plot a reliability diagram comparing confidence vs accuracy.

plot_smooth_reliability_diagram(logits, labels)

Plot a smooth reliability diagram using kernel smoothing.

plot_confidence_histogram(logits[, n_bins, ...])

Plot a histogram of model confidences (max softmax probability).

plot_calibration_curve(logits, labels[, ...])

Plot a calibration curve: accuracy vs. confidence per bin.

Utilities

get_bin_stats(confidences, accuracies, n_bins)

Helper to compute per-bin average confidence, accuracy, and counts.

extract_confidences_and_predictions(logits)

Extract confidence scores and predictions from logits.

logits_to_probs(logits)

Convert logits to probabilities using softmax.
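The conversion is the standard numerically stable softmax: subtracting the per-row maximum before exponentiating leaves the result unchanged but prevents overflow for large logits. A minimal sketch:

```python
import numpy as np

def logits_to_probs(logits):
    # subtract the row max so exp never overflows; the ratio is unchanged
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)
```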