# Source code for incerto.utils.visualization

"""
Visualization utilities for examples.

Common plotting functions to avoid duplication across examples.
"""

import logging

import torch
import numpy as np
import matplotlib.pyplot as plt
from typing import Optional, List

# Module-level logger; the plotting helpers below use it to report where
# figures are saved.
logger = logging.getLogger(__name__)


def plot_training_curves(
    train_losses: List[float],
    val_losses: Optional[List[float]] = None,
    train_accs: Optional[List[float]] = None,
    val_accs: Optional[List[float]] = None,
    title: str = "Training Curves",
    save_path: Optional[str] = None,
    show: bool = True,
):
    """
    Draw per-epoch loss and accuracy curves side by side.

    The left panel always shows the training loss (plus the validation loss
    when given). The right panel shows whichever accuracy series are
    supplied and is left as an undecorated, empty axes when none are.
    Every series is plotted against epochs ``1..len(train_losses)``.

    Args:
        train_losses: Training losses per epoch
        val_losses: Validation losses per epoch (optional)
        train_accs: Training accuracies per epoch (optional)
        val_accs: Validation accuracies per epoch (optional)
        title: Figure-level title
        save_path: Path to save figure (optional)
        show: Whether to call plt.show() (default: True)
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))
    epoch_axis = range(1, len(train_losses) + 1)

    # Left panel: losses
    loss_ax.plot(epoch_axis, train_losses, "b-", label="Train Loss")
    if val_losses is not None:
        loss_ax.plot(epoch_axis, val_losses, "r-", label="Val Loss")
    loss_ax.set(xlabel="Epoch", ylabel="Loss", title="Loss")
    loss_ax.legend()
    loss_ax.grid(True, alpha=0.3)

    # Right panel: accuracies (only the series that were provided)
    accuracy_series = (
        (train_accs, "b-", "Train Acc"),
        (val_accs, "r-", "Val Acc"),
    )
    has_accuracy = False
    for values, line_format, series_label in accuracy_series:
        if values is None:
            continue
        acc_ax.plot(epoch_axis, values, line_format, label=series_label)
        has_accuracy = True

    # Decorate the accuracy panel only when it actually contains a curve.
    if has_accuracy:
        acc_ax.set(xlabel="Epoch", ylabel="Accuracy (%)", title="Accuracy")
        acc_ax.legend()
        acc_ax.grid(True, alpha=0.3)

    fig.suptitle(title)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        logger.info("Saved plot to %s", save_path)

    if show:
        plt.show()
def plot_uncertainty_distribution(
    uncertainties: torch.Tensor,
    correct_mask: torch.Tensor,
    title: str = "Uncertainty Distribution",
    xlabel: str = "Uncertainty",
    save_path: Optional[str] = None,
    show: bool = True,
):
    """
    Plot uncertainty distribution for correct vs incorrect predictions.

    Produces a two-panel figure: overlaid histograms of the two groups on
    the left and a box plot of the same two groups on the right.

    Args:
        uncertainties: Uncertainty scores (N,)
        correct_mask: Mask marking correct predictions (N,). Any dtype is
            accepted; non-zero entries are treated as "correct".
        title: Plot title
        xlabel: X-axis label
        save_path: Path to save figure (optional)
        show: Whether to call plt.show() (default: True)
    """
    uncertainties = uncertainties.detach().cpu().numpy()
    # Force a boolean dtype. With a 0/1 integer mask, NumPy would otherwise
    # do integer fancy-indexing and `~` would be a bitwise NOT (-1/-2),
    # silently selecting the wrong samples; a float mask would raise.
    correct_mask = correct_mask.detach().cpu().numpy().astype(bool)

    correct_unc = uncertainties[correct_mask]
    incorrect_unc = uncertainties[~correct_mask]

    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    # Histogram: both groups share one set of bin edges computed over all
    # scores, so bar heights of the two overlaid histograms are comparable.
    bin_edges = np.histogram_bin_edges(uncertainties, bins=30)
    axes[0].hist(
        correct_unc, bins=bin_edges, alpha=0.6, label="Correct", color="green"
    )
    axes[0].hist(
        incorrect_unc, bins=bin_edges, alpha=0.6, label="Incorrect", color="red"
    )
    axes[0].set_xlabel(xlabel)
    axes[0].set_ylabel("Count")
    axes[0].set_title("Histogram")
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # Box plot. Tick labels are set on the axis rather than through the
    # `labels=` keyword, which newer Matplotlib releases deprecate in favour
    # of `tick_labels=`; this form works on both old and new versions.
    axes[1].boxplot([correct_unc, incorrect_unc], patch_artist=True)
    axes[1].set_xticks([1, 2])  # default box positions are 1..N
    axes[1].set_xticklabels(["Correct", "Incorrect"])
    axes[1].set_ylabel(xlabel)
    axes[1].set_title("Box Plot")
    axes[1].grid(True, alpha=0.3, axis="y")

    fig.suptitle(title)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        logger.info("Saved plot to %s", save_path)

    if show:
        plt.show()
def plot_2d_classification(
    X: np.ndarray,
    y: np.ndarray,
    model: Optional[torch.nn.Module] = None,
    device: str = "cpu",
    title: str = "Classification",
    save_path: Optional[str] = None,
    show: bool = True,
):
    """
    Plot 2D classification data and decision boundary.

    When a model is given, it is evaluated on a regular grid covering the
    data range (padded by 0.5 on every side) and the predicted class of each
    grid point is drawn as a filled contour behind the data points. Outputs
    with more than one column are reduced with argmax; single-output models
    (shape ``(N, 1)`` or ``(N,)``) are thresholded at 0.

    The model is switched to eval mode only for the duration of the grid
    evaluation; the train/eval state of every sub-module is restored
    afterwards.

    Args:
        X: Input features (N, 2)
        y: Labels (N,)
        model: Trained model (optional)
        device: Device for model inference
        title: Plot title
        save_path: Path to save figure (optional)
        show: Whether to call plt.show() (default: True)
    """
    fig, ax = plt.subplots(figsize=(10, 8))

    # Plot decision boundary if model provided
    if model is not None:
        # NOTE: the grid has a fixed 0.02 step, so its size grows with the
        # data range; un-normalised features can make it very large.
        h = 0.02  # Step size in mesh
        x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
        y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
        xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))

        grid_points = np.c_[xx.ravel(), yy.ravel()]
        grid = torch.as_tensor(grid_points, dtype=torch.float32, device=device)

        # Remember each sub-module's mode so plotting in the middle of a
        # training loop does not leave the model stuck in eval mode.
        saved_modes = [(module, module.training) for module in model.modules()]
        model.eval()
        try:
            with torch.no_grad():
                Z = model(grid).cpu().numpy()
        finally:
            for module, was_training in saved_modes:
                module.training = was_training

        if Z.ndim > 1 and Z.shape[1] > 1:
            # Multi-class
            Z = Z.argmax(axis=1)
        else:
            # Binary: a single score per point, (N, 1) or (N,)
            Z = (Z > 0).astype(int).reshape(-1)

        Z = Z.reshape(xx.shape)
        ax.contourf(xx, yy, Z, alpha=0.3, cmap="RdYlBu")

    # Plot data points
    scatter = ax.scatter(X[:, 0], X[:, 1], c=y, cmap="RdYlBu", edgecolors="k", s=50)
    fig.colorbar(scatter, ax=ax)

    ax.set_xlabel("Feature 1")
    ax.set_ylabel("Feature 2")
    ax.set_title(title)
    ax.grid(True, alpha=0.3)

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        logger.info("Saved plot to %s", save_path)

    if show:
        plt.show()
# Names exported by `from incerto.utils.visualization import *`.
__all__ = [
    "plot_training_curves",
    "plot_uncertainty_distribution",
    "plot_2d_classification",
]