Source code for incerto.ood.metrics
import torch, numpy as np
from sklearn import metrics as skm
def auroc(id_scores, ood_scores):
    """Compute the area under the ROC curve for OOD detection.

    In-distribution samples are labeled 0 and out-of-distribution
    samples 1, so higher scores should indicate "more OOD".

    Args:
        id_scores: Scores for in-distribution samples (lower = more ID-like).
        ood_scores: Scores for out-of-distribution samples (higher = more OOD-like).

    Returns:
        AUROC score as a Python float.
    """
    n_id = len(id_scores)
    n_ood = len(ood_scores)
    all_labels = np.concatenate([np.zeros(n_id), np.ones(n_ood)])
    all_scores = torch.cat([id_scores, ood_scores]).cpu().numpy()
    return float(skm.roc_auc_score(all_labels, all_scores))
def fpr_at_tpr(id_scores, ood_scores, tpr=0.95):
    """Compute the false positive rate at a target true positive rate.

    Builds the ROC curve with ID samples labeled 0 and OOD samples
    labeled 1, then linearly interpolates FPR at the requested TPR.

    Args:
        id_scores: Scores for in-distribution samples.
        ood_scores: Scores for out-of-distribution samples.
        tpr: Target true positive rate (default: 0.95).

    Returns:
        FPR at the target TPR as a Python float.
    """
    all_labels = np.concatenate(
        [np.zeros(len(id_scores)), np.ones(len(ood_scores))]
    )
    all_scores = torch.cat([id_scores, ood_scores]).cpu().numpy()
    fpr_curve, tpr_curve, _ = skm.roc_curve(all_labels, all_scores)
    # roc_curve returns TPR in non-decreasing order, which is what
    # np.interp requires for its x-coordinates.
    return float(np.interp(tpr, tpr_curve, fpr_curve))
def detection_accuracy(id_scores, ood_scores):
    """Compute detection accuracy at the 95% TPR threshold.

    The threshold is the 0.95 quantile of the ID scores, so (about) 95%
    of ID samples fall at or below it. ID samples at or below the
    threshold and OOD samples above it count as correct.

    Args:
        id_scores: Scores for in-distribution samples.
        ood_scores: Scores for out-of-distribution samples.

    Returns:
        Detection accuracy (fraction of correctly classified samples).
    """
    cutoff = torch.quantile(id_scores, 0.95)
    n_total = len(id_scores) + len(ood_scores)
    id_correct = (id_scores <= cutoff).sum()
    ood_correct = (ood_scores > cutoff).sum()
    return float((id_correct + ood_correct).item() / n_total)