Source code for incerto.bayesian.methods

"""
Bayesian Deep Learning Methods.

Implementations of popular Bayesian approaches for uncertainty quantification
in neural networks.
"""

from __future__ import annotations
from typing import List, Optional, Callable, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader


class MCDropout(nn.Module):
    """
    Monte Carlo Dropout for uncertainty estimation.

    Applies dropout at test time and aggregates predictions from multiple
    forward passes to estimate predictive uncertainty.

    Reference:
        Gal & Ghahramani, "Dropout as a Bayesian Approximation" (ICML 2016)

    Args:
        model: Base neural network
        num_samples: Number of MC samples for inference (default: 20)
        dropout_rate: Dropout probability (default: 0.1)

    Example:
        >>> backbone = ResNet18(num_classes=10)
        >>> mc_model = MCDropout(backbone, num_samples=20)
        >>> mean, variance = mc_model.predict(x)
    """

    def __init__(
        self,
        model: nn.Module,
        num_samples: int = 20,
        dropout_rate: float = 0.1,
    ):
        super().__init__()
        self.model = model
        self.num_samples = num_samples
        self.dropout_rate = dropout_rate

        # Enable dropout in all dropout layers
        self._enable_dropout()

    def _enable_dropout(self):
        """Enable dropout at test time and apply the configured rate."""
        for module in self.model.modules():
            if isinstance(module, nn.Dropout):
                module.p = self.dropout_rate
                module.train()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Single forward pass (for training)."""
        return self.model(x)

    @torch.no_grad()
    def predict(
        self,
        x: torch.Tensor,
        return_samples: bool = False,
    ) -> Tuple[torch.Tensor, ...]:
        """
        Monte Carlo prediction with uncertainty estimation.

        Args:
            x: Input tensor (N, *)
            return_samples: If True, also return all MC samples

        Returns:
            Tuple of (mean_prediction, predictive_variance).
            If return_samples=True: (mean, variance, samples).
        """
        self._enable_dropout()

        samples = []
        for _ in range(self.num_samples):
            output = self.model(x)
            # Convert logits to probabilities for classification
            if output.dim() == 2 and output.size(-1) > 1:
                output = F.softmax(output, dim=-1)
            samples.append(output)

        samples = torch.stack(samples)  # (num_samples, batch_size, *)

        # Compute mean and variance
        mean = samples.mean(dim=0)
        variance = samples.var(dim=0)

        if return_samples:
            return mean, variance, samples
        return mean, variance

    def predict_entropy(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute predictive entropy (total uncertainty).

        Args:
            x: Input tensor

        Returns:
            Predictive entropy for each sample
        """
        mean_probs, _ = self.predict(x)
        entropy = -(mean_probs * torch.log(mean_probs + 1e-10)).sum(dim=-1)
        return entropy

    def predict_mutual_information(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute mutual information (epistemic uncertainty).

        MI = H[y|x] - E[H[y|x,θ]]

        Args:
            x: Input tensor

        Returns:
            Mutual information for each sample
        """
        _, _, samples = self.predict(x, return_samples=True)

        # Expected entropy: E[H[y|x,θ]]
        expected_entropy = (
            -(samples * torch.log(samples + 1e-10)).sum(dim=-1).mean(dim=0)
        )

        # Entropy of mean: H[E[y|x,θ]]
        mean_probs = samples.mean(dim=0)
        entropy_of_mean = -(mean_probs * torch.log(mean_probs + 1e-10)).sum(dim=-1)

        # Mutual information
        mutual_info = entropy_of_mean - expected_entropy
        return mutual_info
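

# --- Illustrative usage sketch (not part of the library API) ---
# A minimal, hedged example of how MCDropout is meant to be driven: the toy
# MLP, the tensor shapes, and the `_example_mc_dropout` name are assumptions
# made purely for illustration.
def _example_mc_dropout():
    # Toy classifier containing a dropout layer that MCDropout keeps active.
    backbone = nn.Sequential(
        nn.Linear(16, 64),
        nn.ReLU(),
        nn.Dropout(p=0.1),
        nn.Linear(64, 3),
    )
    mc_model = MCDropout(backbone, num_samples=10, dropout_rate=0.1)

    x = torch.randn(8, 16)
    mean, variance = mc_model.predict(x)                    # (8, 3) each
    entropy = mc_model.predict_entropy(x)                   # total uncertainty, (8,)
    mutual_info = mc_model.predict_mutual_information(x)    # epistemic part, (8,)
    return mean, variance, entropy, mutual_info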


class DeepEnsemble(nn.Module):
    """
    Deep Ensembles for uncertainty quantification.

    Trains multiple neural networks independently and aggregates their
    predictions. This is one of the most effective methods for uncertainty
    estimation in deep learning.

    Reference:
        Lakshminarayanan et al., "Simple and Scalable Predictive Uncertainty
        Estimation using Deep Ensembles" (NeurIPS 2017)

    Args:
        model_fn: Function that creates a new model instance
        num_models: Number of ensemble members (default: 5)

    Example:
        >>> def create_model():
        ...     return ResNet18(num_classes=10)
        >>> ensemble = DeepEnsemble(create_model, num_models=5)
        >>> # Train each model separately
        >>> for i in range(5):
        ...     train_model(ensemble.models[i], train_loader)
        >>> mean, variance = ensemble.predict(x)
    """

    def __init__(
        self,
        model_fn: Callable[[], nn.Module],
        num_models: int = 5,
    ):
        super().__init__()
        self.num_models = num_models
        self.models = nn.ModuleList([model_fn() for _ in range(num_models)])

    def forward(self, x: torch.Tensor, model_idx: Optional[int] = None) -> torch.Tensor:
        """
        Forward pass through a specific model or all models.

        Args:
            x: Input tensor
            model_idx: If specified, use only this model. Otherwise average all.

        Returns:
            Model output(s)
        """
        if model_idx is not None:
            return self.models[model_idx](x)

        # Average predictions from all models
        outputs = [model(x) for model in self.models]
        return torch.stack(outputs).mean(dim=0)

    @torch.no_grad()
    def predict(
        self,
        x: torch.Tensor,
        return_all: bool = False,
    ) -> Tuple[torch.Tensor, ...]:
        """
        Ensemble prediction with uncertainty.

        Args:
            x: Input tensor (N, *)
            return_all: If True, also return all individual predictions

        Returns:
            Tuple of (mean_prediction, predictive_variance).
            If return_all=True: (mean, variance, all_predictions).
        """
        predictions = []
        for model in self.models:
            model.eval()
            output = model(x)
            # Convert to probabilities for classification
            if output.dim() == 2 and output.size(-1) > 1:
                output = F.softmax(output, dim=-1)
            predictions.append(output)

        predictions = torch.stack(predictions)  # (num_models, batch_size, *)

        mean = predictions.mean(dim=0)
        variance = predictions.var(dim=0)

        if return_all:
            return mean, variance, predictions
        return mean, variance

    def train_member(
        self,
        model_idx: int,
        train_loader: DataLoader,
        optimizer: torch.optim.Optimizer,
        criterion: nn.Module,
        num_epochs: int = 10,
        device: str = "cuda",
    ):
        """
        Train a specific ensemble member.

        Args:
            model_idx: Index of model to train
            train_loader: Training data loader
            optimizer: Optimizer instance (built for this member's parameters)
            criterion: Loss function
            num_epochs: Number of training epochs
            device: Device to train on
        """
        model = self.models[model_idx].to(device)
        model.train()

        for _ in range(num_epochs):
            for batch_x, batch_y in train_loader:
                batch_x, batch_y = batch_x.to(device), batch_y.to(device)

                optimizer.zero_grad()
                outputs = model(batch_x)
                loss = criterion(outputs, batch_y)
                loss.backward()
                optimizer.step()

    def diversity(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute ensemble diversity (disagreement).

        Args:
            x: Input tensor

        Returns:
            Per-sample diversity score
        """
        _, variance, _ = self.predict(x, return_all=True)
        return variance.mean(dim=-1)
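

# --- Illustrative usage sketch (not part of the library API) ---
# A hedged, self-contained sketch of training each ensemble member on the same
# loader and then querying the ensemble; the toy data, model factory, and the
# `_example_deep_ensemble` name are assumptions for illustration only.
def _example_deep_ensemble():
    from torch.utils.data import TensorDataset

    def make_model():
        return nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 3))

    ensemble = DeepEnsemble(make_model, num_models=3)

    # Tiny synthetic classification dataset.
    dataset = TensorDataset(torch.randn(64, 16), torch.randint(0, 3, (64,)))
    loader = DataLoader(dataset, batch_size=16, shuffle=True)

    # Each member gets its own optimizer and is trained independently.
    for i in range(ensemble.num_models):
        optimizer = torch.optim.Adam(ensemble.models[i].parameters(), lr=1e-3)
        ensemble.train_member(
            i, loader, optimizer, nn.CrossEntropyLoss(), num_epochs=1, device="cpu"
        )

    mean, variance = ensemble.predict(torch.randn(8, 16))
    return mean, variance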


class SWAG(nn.Module):
    """
    Stochastic Weight Averaging - Gaussian (SWAG).

    Approximates the posterior over weights using a Gaussian distribution
    fitted to the trajectory of SGD. Efficient and scalable Bayesian
    inference method.

    Reference:
        Maddox et al., "A Simple Baseline for Bayesian Uncertainty Estimation
        in Deep Learning" (NeurIPS 2019)

    Args:
        model: Base neural network
        num_samples: Number of samples for prediction (default: 20)
        max_models: Maximum number of models to store (default: 20)
        var_clamp: Variance clamping value (default: 1e-6)

    Example:
        >>> model = ResNet18(num_classes=10)
        >>> swag = SWAG(model)
        >>> # Train normally, then collect SWAG statistics
        >>> for epoch in range(start_swag, num_epochs):
        ...     train_epoch(model, ...)
        ...     swag.collect_model(model)
        >>> mean, variance = swag.predict(x)
    """

    def __init__(
        self,
        model: nn.Module,
        num_samples: int = 20,
        max_models: int = 20,
        var_clamp: float = 1e-6,
    ):
        super().__init__()
        self.model = model
        self.num_samples = num_samples
        self.max_models = max_models
        self.var_clamp = var_clamp

        # Statistics for SWAG
        self.mean = {}
        self.sq_mean = {}
        self.cov_mat_sqrt = {}
        self.n_models = 0

        # Initialize statistics
        for name, param in model.named_parameters():
            self.mean[name] = torch.zeros_like(param.data)
            self.sq_mean[name] = torch.zeros_like(param.data)

    def collect_model(self, model: nn.Module):
        """
        Collect model statistics for SWAG.

        Call this method periodically during training (e.g., every epoch
        after a warmup period) to collect weight statistics.

        Args:
            model: Current model to collect statistics from
        """
        self.n_models += 1

        for name, param in model.named_parameters():
            # Update first moment
            self.mean[name] = (
                (self.n_models - 1) / self.n_models * self.mean[name]
                + 1.0 / self.n_models * param.data
            )
            # Update second moment
            self.sq_mean[name] = (
                (self.n_models - 1) / self.n_models * self.sq_mean[name]
                + 1.0 / self.n_models * param.data**2
            )

    def sample_parameters(self) -> dict:
        """
        Sample parameters from the SWAG posterior.

        Returns:
            Dictionary of sampled parameters
        """
        sampled_params = {}
        for name in self.mean.keys():
            # Compute variance
            var = torch.clamp(
                self.sq_mean[name] - self.mean[name] ** 2, min=self.var_clamp
            )
            # Sample from Gaussian
            std = torch.sqrt(var)
            sampled_params[name] = self.mean[name] + torch.randn_like(std) * std
        return sampled_params

    @torch.no_grad()
    def predict(
        self,
        x: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        SWAG prediction with uncertainty.

        Args:
            x: Input tensor (N, *)

        Returns:
            Tuple of (mean_prediction, predictive_variance)
        """
        if self.n_models == 0:
            raise RuntimeError("No models collected. Call collect_model() first.")

        predictions = []

        # Save original parameters
        original_params = {
            name: param.data.clone() for name, param in self.model.named_parameters()
        }

        # Sample and predict
        for _ in range(self.num_samples):
            sampled_params = self.sample_parameters()

            # Load sampled parameters
            for name, param in self.model.named_parameters():
                param.data = sampled_params[name]

            # Forward pass
            output = self.model(x)
            if output.dim() == 2 and output.size(-1) > 1:
                output = F.softmax(output, dim=-1)
            predictions.append(output)

        # Restore original parameters
        for name, param in self.model.named_parameters():
            param.data = original_params[name]

        predictions = torch.stack(predictions)
        mean = predictions.mean(dim=0)
        variance = predictions.var(dim=0)

        return mean, variance
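

# --- Illustrative usage sketch (not part of the library API) ---
# A hedged sketch of the intended SWAG workflow: train with SGD, start
# collecting weight statistics after a warmup, then predict. The toy model,
# data, warmup length, and `_example_swag` name are assumptions.
def _example_swag():
    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 3))
    swag = SWAG(model, num_samples=10)

    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    x_train = torch.randn(64, 16)
    y_train = torch.randint(0, 3, (64,))

    warmup_epochs = 2
    for epoch in range(5):
        optimizer.zero_grad()
        loss = F.cross_entropy(model(x_train), y_train)
        loss.backward()
        optimizer.step()

        # Collect weight statistics only after the warmup phase.
        if epoch >= warmup_epochs:
            swag.collect_model(model)

    mean, variance = swag.predict(torch.randn(8, 16))
    return mean, variance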


class LaplaceApproximation(nn.Module):
    """
    Laplace Approximation for Bayesian Neural Networks.

    Approximates the posterior over weights using a Gaussian centered at the
    MAP estimate (trained weights) with covariance derived from the Hessian
    of the loss function.

    Reference:
        MacKay, "A Practical Bayesian Framework for Backpropagation Networks" (1992)
        Daxberger et al., "Laplace Redux" (NeurIPS 2021)

    Args:
        model: Trained neural network (at MAP estimate)
        likelihood: Likelihood type ('classification' or 'regression')
        prior_precision: Prior precision (inverse variance) (default: 1.0)
        num_samples: Number of posterior samples for prediction (default: 20)

    Example:
        >>> model = ResNet18(num_classes=10)
        >>> # Train model to convergence
        >>> train_model(model, train_loader)
        >>> laplace = LaplaceApproximation(model, likelihood='classification')
        >>> laplace.fit(train_loader)
        >>> mean, variance = laplace.predict(x)
    """

    def __init__(
        self,
        model: nn.Module,
        likelihood: str = "classification",
        prior_precision: float = 1.0,
        num_samples: int = 20,
    ):
        super().__init__()
        self.model = model
        self.likelihood = likelihood
        self.prior_precision = prior_precision
        self.num_samples = num_samples

        # Will store the posterior precision and the MAP estimate
        self.posterior_precision = None
        self.mean = None

    def _compute_hessian_diag(
        self,
        data_loader: DataLoader,
        device: str = "cuda",
    ) -> dict:
        """
        Compute a diagonal Hessian approximation from squared gradients
        (simplified Laplace).

        Args:
            data_loader: Data loader for computing the approximation
            device: Device to use

        Returns:
            Dictionary of Hessian diagonal per parameter
        """
        hessian_diag = {}

        # Initialize
        for name, param in self.model.named_parameters():
            hessian_diag[name] = torch.zeros_like(param.data)

        self.model.to(device)
        self.model.eval()

        for batch_x, batch_y in data_loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)

            # Reset gradients so batches do not accumulate into each other
            self.model.zero_grad()

            # Forward pass and loss
            outputs = self.model(batch_x)
            if self.likelihood == "classification":
                loss = F.cross_entropy(outputs, batch_y)
            else:  # regression with Gaussian likelihood
                loss = F.mse_loss(outputs.squeeze(), batch_y.float())

            loss.backward()

            # Accumulate squared gradients (empirical Fisher diagonal)
            for name, param in self.model.named_parameters():
                if param.grad is not None:
                    hessian_diag[name] += param.grad.data**2

        # Average over dataset
        num_batches = len(data_loader)
        for name in hessian_diag:
            hessian_diag[name] /= num_batches

        return hessian_diag

    def fit(self, data_loader: DataLoader, device: str = "cuda"):
        """
        Fit Laplace approximation by computing the Hessian approximation.

        Args:
            data_loader: Data loader for computing Hessian
            device: Device to use
        """
        # Store MAP estimate (current weights)
        self.mean = {
            name: param.data.clone() for name, param in self.model.named_parameters()
        }

        # Compute Hessian diagonal
        hessian_diag = self._compute_hessian_diag(data_loader, device)

        # Posterior precision = prior precision + Hessian
        self.posterior_precision = {
            name: self.prior_precision + hessian_diag[name] for name in hessian_diag
        }

    @torch.no_grad()
    def predict(
        self,
        x: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Laplace prediction with uncertainty.

        Args:
            x: Input tensor (N, *)

        Returns:
            Tuple of (mean_prediction, predictive_variance)
        """
        if self.mean is None:
            raise RuntimeError("Model not fitted. Call fit() first.")

        predictions = []

        # Save original parameters
        original_params = {
            name: param.data.clone() for name, param in self.model.named_parameters()
        }

        # Sample from posterior
        for _ in range(self.num_samples):
            for name, param in self.model.named_parameters():
                # Sample from Gaussian posterior
                std = 1.0 / torch.sqrt(self.posterior_precision[name])
                param.data = self.mean[name] + torch.randn_like(std) * std

            output = self.model(x)
            if output.dim() == 2 and output.size(-1) > 1:
                output = F.softmax(output, dim=-1)
            predictions.append(output)

        # Restore original parameters
        for name, param in self.model.named_parameters():
            param.data = original_params[name]

        predictions = torch.stack(predictions)
        mean = predictions.mean(dim=0)
        variance = predictions.var(dim=0)

        return mean, variance
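

# --- Illustrative usage sketch (not part of the library API) ---
# A hedged sketch of fitting the diagonal Laplace approximation on CPU after
# (notionally) training the model to its MAP estimate; the toy model, data,
# and `_example_laplace` name are assumptions for illustration.
def _example_laplace():
    from torch.utils.data import TensorDataset

    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 3))
    # ... train `model` to convergence here (omitted) ...

    dataset = TensorDataset(torch.randn(64, 16), torch.randint(0, 3, (64,)))
    loader = DataLoader(dataset, batch_size=16)

    laplace = LaplaceApproximation(
        model, likelihood="classification", prior_precision=1.0, num_samples=10
    )
    laplace.fit(loader, device="cpu")

    mean, variance = laplace.predict(torch.randn(8, 16))
    return mean, variance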


class VariationalBayesNN(nn.Module):
    """
    Variational Bayesian Neural Network (Bayes by Backprop).

    Learns a distribution over weights using variational inference.
    Each weight has a learned mean and variance.

    Reference:
        Blundell et al., "Weight Uncertainty in Neural Networks" (ICML 2015)

    Args:
        in_features: Input dimension
        hidden_sizes: List of hidden layer sizes
        out_features: Output dimension
        prior_std: Prior standard deviation (default: 1.0)

    Example:
        >>> model = VariationalBayesNN(784, [512, 256], 10)
        >>> # Train with variational loss
        >>> for x, y in train_loader:
        ...     loss = model.variational_loss(x, y, num_samples=10)
        ...     loss.backward()
        >>> mean, variance = model.predict(x)
    """

    def __init__(
        self,
        in_features: int,
        hidden_sizes: List[int],
        out_features: int,
        prior_std: float = 1.0,
    ):
        super().__init__()
        self.prior_std = prior_std

        # Build network with Gaussian weights
        layers = []
        prev_size = in_features
        for hidden_size in hidden_sizes:
            layers.append(GaussianLinear(prev_size, hidden_size, prior_std=prior_std))
            prev_size = hidden_size
        layers.append(GaussianLinear(prev_size, out_features, prior_std=prior_std))
        self.layers = nn.ModuleList(layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass with sampled weights."""
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < len(self.layers) - 1:
                x = F.relu(x)
        return x

    def kl_divergence(self) -> torch.Tensor:
        """Compute KL divergence between posterior and prior."""
        kl = 0
        for layer in self.layers:
            kl += layer.kl_divergence()
        return kl

    def variational_loss(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        num_samples: int = 10,
        kl_weight: float = 1.0,
    ) -> torch.Tensor:
        """
        Compute variational loss (negative ELBO).

        Loss = E[NLL] + KL[q(w) || p(w)]

        Args:
            x: Input tensor
            y: Target labels
            num_samples: Number of samples for MC estimate
            kl_weight: Weight for KL term

        Returns:
            Variational loss
        """
        # Likelihood term (averaged over samples)
        nll = 0
        for _ in range(num_samples):
            outputs = self.forward(x)
            nll += F.cross_entropy(outputs, y)
        nll /= num_samples

        # KL term
        kl = self.kl_divergence()

        # Negative ELBO (KL scaled by batch size)
        return nll + kl_weight * kl / len(x)

    @torch.no_grad()
    def predict(
        self,
        x: torch.Tensor,
        num_samples: int = 20,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Variational prediction with uncertainty.

        Args:
            x: Input tensor
            num_samples: Number of MC samples

        Returns:
            Tuple of (mean_prediction, predictive_variance)
        """
        predictions = []
        for _ in range(num_samples):
            output = self.forward(x)
            if output.dim() == 2 and output.size(-1) > 1:
                output = F.softmax(output, dim=-1)
            predictions.append(output)

        predictions = torch.stack(predictions)
        mean = predictions.mean(dim=0)
        variance = predictions.var(dim=0)

        return mean, variance


class GaussianLinear(nn.Module):
    """
    Linear layer with Gaussian weights for variational inference.

    Each weight has a learnable mean (mu) and a parameter (rho) that is
    mapped to a standard deviation via softplus.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        prior_std: float = 1.0,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.prior_std = prior_std

        # Variational parameters
        self.weight_mu = nn.Parameter(torch.randn(out_features, in_features) * 0.1)
        self.weight_rho = nn.Parameter(torch.randn(out_features, in_features) * 0.1)
        self.bias_mu = nn.Parameter(torch.randn(out_features) * 0.1)
        self.bias_rho = nn.Parameter(torch.randn(out_features) * 0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass with sampled weights (reparameterization trick)."""
        # Sample weights
        weight_std = torch.log1p(torch.exp(self.weight_rho))
        weight = self.weight_mu + weight_std * torch.randn_like(weight_std)

        # Sample bias
        bias_std = torch.log1p(torch.exp(self.bias_rho))
        bias = self.bias_mu + bias_std * torch.randn_like(bias_std)

        return F.linear(x, weight, bias)

    def kl_divergence(self) -> torch.Tensor:
        """KL divergence between posterior and prior."""
        # Weight KL
        weight_std = torch.log1p(torch.exp(self.weight_rho))
        weight_kl = (
            torch.log(self.prior_std / weight_std)
            + (weight_std**2 + self.weight_mu**2) / (2 * self.prior_std**2)
            - 0.5
        ).sum()

        # Bias KL
        bias_std = torch.log1p(torch.exp(self.bias_rho))
        bias_kl = (
            torch.log(self.prior_std / bias_std)
            + (bias_std**2 + self.bias_mu**2) / (2 * self.prior_std**2)
            - 0.5
        ).sum()

        return weight_kl + bias_kl


__all__ = [
    "MCDropout",
    "DeepEnsemble",
    "SWAG",
    "LaplaceApproximation",
    "VariationalBayesNN",
    "GaussianLinear",
]