"""
Metrics for evaluating Bayesian deep learning models.
"""

from __future__ import annotations
import torch
import numpy as np
from typing import Tuple


def ensemble_diversity(
    predictions: torch.Tensor,
    metric: str = "variance",
) -> float:
    """
    Measure diversity among ensemble members.

    Args:
        predictions: Predicted class probabilities of shape (num_models, batch_size, num_classes)
        metric: Diversity metric ('variance', 'disagreement', 'kl')

    Returns:
        Diversity score (higher = more diverse)
    """
    if metric == "variance":
        # Average variance across predictions
        variance = predictions.var(dim=0)
        return variance.mean().item()

    elif metric == "disagreement":
        # Fraction of samples where models disagree
        pred_labels = predictions.argmax(dim=-1)  # (num_models, batch_size)
        # Check if all models agree
        mode = torch.mode(pred_labels, dim=0).values
        agreement = (pred_labels == mode.unsqueeze(0)).all(dim=0)
        disagreement_rate = 1.0 - agreement.float().mean()
        return disagreement_rate.item()

    elif metric == "kl":
        # Average KL divergence between pairs
        num_models = predictions.size(0)
        kl_sum = 0.0
        count = 0

        for i in range(num_models):
            for j in range(i + 1, num_models):
                log_ratio = torch.log(predictions[i] + 1e-10) - torch.log(
                    predictions[j] + 1e-10
                )
                kl = (predictions[i] * log_ratio).sum(dim=-1).mean()
                kl_sum += kl.item()
                count += 1

        return kl_sum / count if count > 0 else 0.0

    else:
        raise ValueError(f"Unknown metric: {metric}")


def uncertainty_quality(
    uncertainties: torch.Tensor,
    errors: torch.Tensor,
) -> Tuple[float, float]:
    """
    Evaluate quality of uncertainty estimates.

    Measures correlation between uncertainty and prediction error.
    Good uncertainty estimates should be high when the model is wrong and low
    when it is correct.

    Args:
        uncertainties: Predicted uncertainties (batch_size,)
        errors: Binary error indicators (batch_size,) where 1=error, 0=correct

    Returns:
        Tuple of (spearman_correlation, auroc)
    """
    # Spearman correlation between uncertainty and error
    uncertainties_np = uncertainties.detach().cpu().numpy()
    errors_np = errors.detach().cpu().numpy()

    # Rank correlation
    from scipy.stats import spearmanr

    correlation, _ = spearmanr(uncertainties_np, errors_np)
    if np.isnan(correlation):
        # Rank correlation is undefined when either input is constant
        correlation = 0.0

    # AUROC for using uncertainty to detect errors
    from sklearn.metrics import roc_auc_score

    try:
        auroc = roc_auc_score(errors_np, uncertainties_np)
    except ValueError:
        # All errors are the same (all correct or all wrong)
        auroc = 0.5

    return float(correlation), float(auroc)
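
# Illustrative sketch (assumed inputs, not part of the library): a common choice
# is predictive entropy of the ensemble mean as the uncertainty and
# misclassification as the error indicator.
#
#   probs = torch.softmax(torch.randn(5, 32, 10), dim=-1)
#   mean_probs = probs.mean(dim=0)                                   # (32, 10)
#   unc = -(mean_probs * torch.log(mean_probs + 1e-10)).sum(dim=-1)  # predictive entropy
#   labels = torch.randint(0, 10, (32,))
#   errs = (mean_probs.argmax(dim=-1) != labels).float()
#   rho, auroc = uncertainty_quality(unc, errs)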


def disagreement(
    predictions: torch.Tensor,
    method: str = "variance",
) -> torch.Tensor:
    """
    Compute disagreement score for each sample.

    Args:
        predictions: Tensor of shape (num_models, batch_size, num_classes)
        method: Disagreement method ('variance', 'entropy')

    Returns:
        Disagreement scores of shape (batch_size,)
    """
    if method == "variance":
        # Variance of predictions
        variance = predictions.var(dim=0)
        return variance.mean(dim=-1)

    elif method == "entropy":
        # Entropy of the mean prediction
        mean_probs = predictions.mean(dim=0)
        entropy = -(mean_probs * torch.log(mean_probs + 1e-10)).sum(dim=-1)
        return entropy

    else:
        raise ValueError(f"Unknown method: {method}")
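
# Sketch (hypothetical use): per-sample scores can rank inputs by how contested
# they are among ensemble members, e.g. for acquisition in active learning.
#
#   probs = torch.softmax(torch.randn(5, 32, 10), dim=-1)
#   scores = disagreement(probs, method="variance")  # (batch_size,)
#   most_uncertain = scores.topk(8).indices          # pick the 8 most contested samples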


def negative_log_likelihood(
    predictions: torch.Tensor,
    labels: torch.Tensor,
) -> float:
    """
    Compute negative log-likelihood of predictions.

    Args:
        predictions: Mean predictions (batch_size, num_classes)
        labels: True labels (batch_size,)

    Returns:
        NLL score
    """
    log_probs = torch.log(predictions + 1e-10)
    nll = -log_probs[torch.arange(len(labels)), labels].mean()
    return nll.item()


def brier_score(
    predictions: torch.Tensor,
    labels: torch.Tensor,
) -> float:
    """
    Compute Brier score for probabilistic predictions.

    Args:
        predictions: Predictions (batch_size, num_classes)
        labels: True labels (batch_size,)

    Returns:
        Brier score
    """
    one_hot = torch.zeros_like(predictions)
    one_hot[torch.arange(len(labels)), labels] = 1.0
    brier = ((predictions - one_hot) ** 2).sum(dim=-1).mean()
    return brier.item()


def predictive_log_likelihood(
    predictions: torch.Tensor,
    labels: torch.Tensor,
) -> float:
    """
    Compute predictive log-likelihood (averaged over ensemble).

    Args:
        predictions: Ensemble predictions (num_models, batch_size, num_classes)
        labels: True labels (batch_size,)

    Returns:
        Predictive log-likelihood
    """
    # Average predictions
    mean_probs = predictions.mean(dim=0)

    # Log-likelihood
    log_probs = torch.log(mean_probs + 1e-10)
    ll = log_probs[torch.arange(len(labels)), labels].mean()
    return ll.item()


def sharpness(predictions: torch.Tensor) -> float:
    """
    Compute sharpness of probabilistic predictions.

    Sharpness measures how concentrated the predictive distribution is.
    Lower entropy = sharper predictions.

    Args:
        predictions: Predictions (batch_size, num_classes)

    Returns:
        Average entropy (lower = sharper)
    """
    entropy = -(predictions * torch.log(predictions + 1e-10)).sum(dim=-1)
    return entropy.mean().item()


__all__ = [
    "ensemble_diversity",
    "uncertainty_quality",
    "disagreement",
    "negative_log_likelihood",
    "brier_score",
    "predictive_log_likelihood",
    "sharpness",
]
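

# Illustrative comparison of the scoring rules above (random tensors, purely
# hypothetical values); all of them expect softmax probabilities and integer labels.
#
#   probs = torch.softmax(torch.randn(5, 32, 10), dim=-1)
#   labels = torch.randint(0, 10, (32,))
#   mean_probs = probs.mean(dim=0)
#   negative_log_likelihood(mean_probs, labels)  # lower is better
#   brier_score(mean_probs, labels)              # lower is better, bounded in [0, 2]
#   predictive_log_likelihood(probs, labels)     # higher is better; equals -NLL of the ensemble mean
#   sharpness(mean_probs)                        # mean predictive entropy; lower = sharper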