# Source code for incerto.calibration.visual

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

from .utils import get_bin_stats


def plot_reliability_diagram(
    logits: torch.Tensor,
    labels: torch.Tensor,
    n_bins: int = 10,
    ax=None,
    title: str = "Reliability Diagram",
):
    """
    Plot a reliability diagram comparing confidence vs accuracy.

    Confidence is the maximum softmax probability of each sample, and
    accuracy is whether the argmax prediction equals the label. Both are
    binned by ``get_bin_stats`` and drawn together with the identity line
    (perfect calibration) and per-bin gap bars.

    Args:
        logits: Unnormalized class scores; softmax is taken over dim 1, so
            presumably shaped (N, C) -- TODO confirm with callers.
        labels: Ground-truth class indices, compared elementwise with the
            argmax prediction of each row of ``logits``.
        n_bins: Number of confidence bins (also sets the gap-bar width).
        ax: Matplotlib Axes to draw on. A new figure/axes is created if None.
        title: Axes title.

    Returns:
        The Axes the diagram was drawn on.
    """
    # Detach first so no autograd graph is built for the softmax.
    probs = F.softmax(logits.detach(), dim=1).cpu().numpy()
    confidences = probs.max(axis=1)
    predictions = probs.argmax(axis=1)
    accuracies = (predictions == labels.detach().cpu().numpy()).astype(float)

    bin_conf, bin_acc, _ = get_bin_stats(confidences, accuracies, n_bins)
    bin_conf = np.asarray(bin_conf)
    bin_acc = np.asarray(bin_acc)

    if ax is None:
        _, ax = plt.subplots()

    # Perfect calibration: accuracy equals confidence.
    ax.plot([0, 1], [0, 1], linestyle="--", label="Perfect")
    # Empirical: per-bin mean confidence vs. per-bin accuracy.
    ax.plot(bin_conf, bin_acc, marker="o", label="Empirical")

    # Gap bars span from each bin's mean confidence to its accuracy. Anchoring
    # them at y=0 instead would put negative (overconfident) gaps below the
    # axis, where set_ylim(0, 1) clips them out of view.
    centers = (np.arange(n_bins) + 0.5) / n_bins
    ax.bar(
        centers,
        bin_acc - bin_conf,
        bottom=bin_conf,
        width=1.0 / n_bins,
        alpha=0.3,
        edgecolor="black",
        label="Gap",
    )

    ax.set_xlabel("Confidence")
    ax.set_ylabel("Accuracy")
    ax.set_title(title)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.legend()
    return ax
def plot_confidence_histogram(
    logits: torch.Tensor,
    n_bins: int = 10,
    ax=None,
    title: str = "Confidence Histogram",
):
    """
    Draw a histogram of per-sample confidence.

    Confidence is the largest softmax probability in each row of ``logits``
    (softmax over dim 1).

    Args:
        logits: Unnormalized class scores.
        n_bins: Number of equal-width histogram bins spanning [0, 1].
        ax: Axes to draw into. When None, a fresh figure and axes are created.
        title: Title placed on the axes.

    Returns:
        The Axes containing the histogram.
    """
    softmax_np = F.softmax(logits, dim=1).cpu().detach().numpy()
    max_prob = np.max(softmax_np, axis=1)

    target = ax if ax is not None else plt.subplots()[1]

    target.hist(max_prob, bins=n_bins, range=(0, 1), edgecolor="black")
    target.set_xlabel("Confidence")
    target.set_ylabel("Count")
    target.set_title(title)
    return target
def plot_calibration_curve(
    logits: torch.Tensor,
    labels: torch.Tensor,
    n_bins: int = 10,
    ax=None,
    title: str = "Calibration Curve",
):
    """
    Plot calibration curve: accuracy vs. confidence bin centers.

    Confidence is the maximum softmax probability of each sample, and
    accuracy is whether the argmax prediction equals the label. Per-bin
    accuracy comes from ``get_bin_stats`` and is plotted at the midpoint of
    each of the ``n_bins`` equal-width bins on [0, 1].

    Args:
        logits: Unnormalized class scores; softmax is taken over dim 1, so
            presumably shaped (N, C) -- TODO confirm with callers.
        labels: Ground-truth class indices, compared elementwise with the
            argmax prediction of each row of ``logits``.
        n_bins: Number of confidence bins.
        ax: Matplotlib Axes to draw on. A new figure/axes is created if None.
        title: Axes title.

    Returns:
        The Axes the curve was drawn on.
    """
    # Detach first so no autograd graph is built for the softmax.
    probs = F.softmax(logits.detach(), dim=1).cpu().numpy()
    confidences = probs.max(axis=1)
    predictions = probs.argmax(axis=1)
    accuracies = (predictions == labels.detach().cpu().numpy()).astype(float)

    # NOTE(review): how get_bin_stats reports empty bins is not visible here;
    # confirm they do not show up as spurious 0-accuracy points.
    _, bin_acc, _ = get_bin_stats(confidences, accuracies, n_bins)
    centers = (np.arange(n_bins) + 0.5) / n_bins

    if ax is None:
        _, ax = plt.subplots()

    ax.plot(centers, bin_acc, marker="o")
    ax.set_xlabel("Confidence Bin Center")
    ax.set_ylabel("Accuracy")
    ax.set_title(title)
    # Pin both axes to the full [0, 1] range. Without an x-limit, autoscaling
    # zooms to the outermost bin centers and the confidence scale is lost.
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    return ax