Source code for incerto.ood.visual
import matplotlib.pyplot as plt
import torch
import numpy as np
[docs]
def plot_roc(id_scores, ood_scores, label=None, ax=None):
    """
    Plot ROC curve for OOD detection.

    OOD samples are treated as the positive class (label 1) and ID samples as
    the negative class (label 0), so higher scores are expected to indicate
    "more out-of-distribution".

    Args:
        id_scores: Scores for in-distribution samples (tensor or array-like;
            flattened to 1-D).
        ood_scores: Scores for out-of-distribution samples (tensor or
            array-like; flattened to 1-D).
        label: Label for the plot (used as the curve name in the legend).
        ax: Matplotlib axes object. If None, scikit-learn creates a new figure.

    Returns:
        The Matplotlib axes the ROC curve was drawn on.
    """
    from sklearn.metrics import RocCurveDisplay

    # detach(): Tensor.numpy() raises on tensors that require grad (e.g. raw
    # model outputs). Each input is moved to CPU separately so ID/OOD tensors
    # that live on different devices can still be combined.
    id_np = torch.as_tensor(id_scores).detach().cpu().reshape(-1).numpy()
    ood_np = torch.as_tensor(ood_scores).detach().cpu().reshape(-1).numpy()

    scores = np.concatenate([id_np, ood_np])
    # OOD is the positive class: 0 for ID samples, 1 for OOD samples.
    labels = np.concatenate([np.zeros(id_np.size), np.ones(ood_np.size)])

    display = RocCurveDisplay.from_predictions(labels, scores, ax=ax, name=label)
    # Use the axes the display actually drew on rather than re-querying the
    # global "current axes".
    ax = display.ax_
    ax.set_aspect("equal", adjustable="box")
    return ax
[docs]
def score_hist(id_scores, ood_scores, ax=None, bins=50):
    """
    Plot histogram of OOD scores for ID and OOD samples.

    Both histograms are drawn on the same bin edges, resolved once over the
    pooled ID + OOD scores, so bar widths match and heights are directly
    comparable between the two distributions.

    Args:
        id_scores: Scores for in-distribution samples (tensor or array-like;
            flattened to 1-D).
        ood_scores: Scores for out-of-distribution samples (tensor or
            array-like; flattened to 1-D).
        ax: Matplotlib axes object. Defaults to the current axes
            (``plt.gca()``).
        bins: Number of histogram bins, an explicit sequence of bin edges, or
            a NumPy binning-strategy name (e.g. ``"auto"``).

    Returns:
        Matplotlib axes object.
    """
    if ax is None:
        ax = plt.gca()

    # detach(): converting a tensor that requires grad to NumPy raises.
    id_np = torch.as_tensor(id_scores).detach().cpu().reshape(-1).numpy()
    ood_np = torch.as_tensor(ood_scores).detach().cpu().reshape(-1).numpy()

    # Shared edges: binning each set over its own range would give the two
    # histograms different bin widths and make them visually incomparable.
    edges = np.histogram_bin_edges(np.concatenate([id_np, ood_np]), bins=bins)

    ax.hist(id_np, bins=edges, alpha=0.6, label="ID")
    ax.hist(ood_np, bins=edges, alpha=0.6, label="OOD")
    ax.set_xlabel("OOD score")
    ax.set_ylabel("# samples")
    ax.legend()
    return ax