Source code for incerto.sp.metrics
"""
Core metrics for selective prediction.
Notation:
Y – ground-truth labels
Ŷ – model predictions
R – reject mask; 1 = reject/defer, 0 = predict
"""
from __future__ import annotations
import torch
from typing import Tuple
# ----------------------------------------------------------------------
# BASIC MEASURES ON A BATCH
# ----------------------------------------------------------------------
def coverage(reject: torch.Tensor) -> torch.Tensor:
    """Proportion of *accepted* samples."""
    return 1.0 - reject.float().mean()
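
# A doctest-style sketch (illustrative values, not part of the module):
# with one of four samples rejected, coverage is 3/4.
#
#   >>> coverage(torch.tensor([0, 1, 0, 0]))
#   tensor(0.7500)
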
def risk(pred: torch.Tensor, y: torch.Tensor, reject: torch.Tensor) -> torch.Tensor:
    """Error rate on *accepted* samples (a.k.a. selective risk)."""
    correct = (pred == y).float()
    accepted = 1.0 - reject.float()
    # eps guards against division by zero when every sample is rejected
    return 1.0 - (correct * accepted).sum() / (accepted.sum() + 1e-9)
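
# A doctest-style sketch (hypothetical tensors): of the three accepted
# samples below, two are correct, so the selective risk is 1 - 2/3.
#
#   >>> pred   = torch.tensor([1, 0, 2, 2])
#   >>> y      = torch.tensor([1, 1, 2, 0])
#   >>> reject = torch.tensor([0, 1, 0, 0])
#   >>> risk(pred, y, reject)
#   tensor(0.3333)
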
def aurc(
    sorted_conf: torch.Tensor,
    sorted_errors: torch.Tensor,
) -> torch.Tensor:
    """
    Area under the Risk-Coverage curve (AURC).

    Both inputs must be sorted *descending* by confidence; `sorted_errors`
    holds 0/1 error indicators. `sorted_conf` is only used for its size
    and device.
    """
    n = sorted_conf.numel()
    ranks = torch.arange(1, n + 1, device=sorted_conf.device)
    cum_errors = torch.cumsum(sorted_errors.float(), dim=0)
    risk_curve = cum_errors / ranks    # selective risk at each prefix
    coverage_curve = ranks / n         # fraction of samples accepted
    # simple trapezoidal rule over the risk-coverage curve
    return torch.trapz(risk_curve, coverage_curve)
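
# A worked sketch (hand-picked values): four samples sorted by descending
# confidence, with a single error in third position. The running risk is
# [0, 0, 1/3, 1/4], so trapz over coverage [.25, .5, .75, 1] gives
# (0 + 1/3)/2 * 0.25 + (1/3 + 1/4)/2 * 0.25 ≈ 0.1146.
#
#   >>> conf = torch.tensor([0.9, 0.8, 0.6, 0.4])
#   >>> errs = torch.tensor([0., 0., 1., 0.])
#   >>> aurc(conf, errs)
#   tensor(0.1146)
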
def accuracy_coverage_curve(
    logits: torch.Tensor,
    y: torch.Tensor,
    confidence: torch.Tensor | None = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Return the `(coverage, accuracy)` curve obtained by sweeping the
    confidence threshold over every sample, from most to least confident.
    If `confidence` is not given, the max softmax probability is used.
    """
    if confidence is None:
        confidence = torch.softmax(logits, dim=-1).max(dim=-1).values
    _, idx = confidence.sort(descending=True)
    sorted_pred = logits.argmax(dim=-1)[idx]
    sorted_y = y[idx]
    cumul_correct = torch.cumsum((sorted_pred == sorted_y).float(), dim=0)
    ranks = torch.arange(1, len(y) + 1, device=y.device)
    coverage_curve = ranks / len(y)         # fraction accepted at each cut
    accuracy_curve = cumul_correct / ranks  # accuracy on the accepted set
    return coverage_curve, accuracy_curve
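
# Minimal smoke test (a sketch; the random tensors are assumptions, not
# fixtures from the library's test suite).
if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(8, 3)     # 8 samples, 3 classes
    y = torch.randint(0, 3, (8,))  # random ground-truth labels
    cov, acc = accuracy_coverage_curve(logits, y)
    print(cov)  # coverage: 0.125, 0.250, ..., 1.000
    print(acc)  # accuracy on the most-confident prefix at each coverage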