"""
Bayesian Deep Learning Methods.
Implementations of popular Bayesian approaches for uncertainty quantification
in neural networks.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List, Optional, Callable, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
class BaseBayesianMethod(nn.Module, ABC):
"""
Abstract base class for Bayesian deep learning methods.
All Bayesian methods produce uncertainty estimates by maintaining or
sampling from a distribution over model parameters (or predictions).
They follow a common interface:
- ``predict(x, return_samples=False)`` returns ``(mean, variance)``
or ``(mean, variance, samples)`` when ``return_samples=True``.
The returned variance captures **epistemic** (model) uncertainty —
disagreement between posterior samples. Aleatoric (data) uncertainty
is not directly modelled by this interface.
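    Example:
        A minimal sketch of the shared interface (``method`` is any fitted
        subclass and ``x`` is an input batch; both are placeholders):
        >>> mean, variance = method.predict(x)
        >>> mean2, var2, samples = method.predict(x, return_samples=True)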
"""
@abstractmethod
def predict(
self,
x: torch.Tensor,
return_samples: bool = False,
**kwargs,
) -> Union[
Tuple[torch.Tensor, torch.Tensor],
Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
]:
"""
Prediction with uncertainty estimation.
Args:
x: Input tensor.
return_samples: If True, also return individual samples/predictions.
Returns:
``(mean, variance)`` or ``(mean, variance, samples)`` when
*return_samples* is True. Variance is computed as disagreement
across posterior samples and captures epistemic (model) uncertainty.
"""
...
class MCDropout(BaseBayesianMethod):
"""
Monte Carlo Dropout for uncertainty estimation.
Applies dropout at test time and aggregates predictions from multiple
forward passes to estimate predictive uncertainty.
Reference:
Gal & Ghahramani, "Dropout as a Bayesian Approximation" (ICML 2016)
Args:
model: Base neural network
num_samples: Number of MC samples for inference (default: 20)
        dropout_rate: Dropout probability applied to every dropout layer;
            pass ``None`` to keep each layer's existing rate (default: 0.1)
Example:
>>> backbone = ResNet18(num_classes=10)
>>> mc_model = MCDropout(backbone, num_samples=20)
>>> mean, variance = mc_model.predict(x)
"""
def __init__(
self,
model: nn.Module,
num_samples: int = 20,
        dropout_rate: Optional[float] = 0.1,
):
super().__init__()
if dropout_rate is not None and not (0 <= dropout_rate < 1):
raise ValueError(f"dropout_rate must be in [0, 1), got {dropout_rate}")
self.model = model
self.num_samples = num_samples
self.dropout_rate = dropout_rate
# Enable dropout in all dropout layers
self._enable_dropout()
def _enable_dropout(self):
"""Enable dropout at test time and optionally set dropout rate."""
found = False
for module in self.model.modules():
if isinstance(
module, (nn.Dropout, nn.Dropout2d, nn.Dropout3d, nn.AlphaDropout)
):
module.train()
# Override dropout rate if specified
if self.dropout_rate is not None:
module.p = self.dropout_rate
found = True
if not found:
import warnings
warnings.warn(
"MCDropout: no nn.Dropout layers found in the model. "
"All forward samples will be identical, yielding zero variance.",
stacklevel=2,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Single forward pass (for training)."""
return self.model(x)
@torch.no_grad()
def predict(
self,
x: torch.Tensor,
return_samples: bool = False,
normalize_output: bool = True,
) -> Union[
Tuple[torch.Tensor, torch.Tensor],
Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
]:
"""
Monte Carlo prediction with uncertainty estimation.
Args:
x: Input tensor ``(N, ...)``
return_samples: If True, return all MC samples
normalize_output: If True, apply softmax to 2-D multi-column
outputs (i.e. treat them as logits). Set to False when
the model already returns probabilities or when outputs
are not classification logits.
Returns:
Tuple of (mean_prediction, predictive_variance)
If return_samples=True: (mean, variance, samples)
"""
        # Keep non-dropout layers (e.g. BatchNorm) in eval mode,
        # then re-enable stochastic dropout for MC sampling.
        self.model.eval()
        self._enable_dropout()
samples = []
for _ in range(self.num_samples):
output = self.model(x)
# Convert logits to probabilities for classification
if normalize_output and output.dim() == 2 and output.size(-1) > 1:
output = F.softmax(output, dim=-1)
samples.append(output)
samples = torch.stack(samples) # (num_samples, batch_size, *)
# Compute mean and variance
mean = samples.mean(dim=0)
variance = samples.var(dim=0)
if return_samples:
return mean, variance, samples
return mean, variance
def predict_entropy(self, x: torch.Tensor) -> torch.Tensor:
"""
Compute predictive entropy (total uncertainty).
Args:
x: Input tensor
Returns:
Predictive entropy for each sample
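        Note:
            The entropy is taken over the MC-mean class probabilities returned
            by ``predict``: ``H(x) = -sum_c p_c(x) * log(p_c(x) + 1e-10)``.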
"""
mean_probs, _ = self.predict(x)
entropy = -(mean_probs * torch.log(mean_probs + 1e-10)).sum(dim=-1)
return entropy
class DeepEnsemble(BaseBayesianMethod):
"""
Deep Ensembles for uncertainty quantification.
Trains multiple neural networks independently and aggregates their
predictions. This is one of the most effective methods for uncertainty
estimation in deep learning.
Reference:
Lakshminarayanan et al., "Simple and Scalable Predictive Uncertainty
Estimation using Deep Ensembles" (NeurIPS 2017)
Args:
model_fn: Function that creates a new model instance
num_models: Number of ensemble members (default: 5)
Example:
>>> def create_model():
... return ResNet18(num_classes=10)
>>> ensemble = DeepEnsemble(create_model, num_models=5)
>>> # Train each model separately
>>> for i in range(5):
... train_model(ensemble.models[i], train_loader)
>>> mean, variance = ensemble.predict(x)
"""
def __init__(
self,
model_fn: Callable[[], nn.Module],
num_models: int = 5,
):
super().__init__()
if not callable(model_fn):
raise TypeError("model_fn must be callable")
if num_models < 1:
raise ValueError(f"num_models must be >= 1, got {num_models}")
self.num_models = num_models
self.models = nn.ModuleList([model_fn() for _ in range(num_models)])
def forward(self, x: torch.Tensor, model_idx: Optional[int] = None) -> torch.Tensor:
"""
Forward pass through a specific model or all models.
Args:
x: Input tensor
model_idx: If specified, use only this model. Otherwise average all.
Returns:
Model output(s)
"""
if model_idx is not None:
return self.models[model_idx](x)
# Average predictions from all models
outputs = [model(x) for model in self.models]
return torch.stack(outputs).mean(dim=0)
@torch.no_grad()
def predict(
self,
x: torch.Tensor,
return_samples: bool = False,
return_all: bool = False,
normalize_output: bool = True,
) -> Union[
Tuple[torch.Tensor, torch.Tensor],
Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
]:
"""
Ensemble prediction with uncertainty.
Args:
x: Input tensor ``(N, ...)``
return_samples: If True, return all individual predictions
return_all: Deprecated alias for return_samples
normalize_output: If True, apply softmax to 2-D multi-column
outputs (i.e. treat them as logits). Set to False when
the model already returns probabilities or when outputs
are not classification logits.
Returns:
Tuple of (mean_prediction, epistemic_variance).
Variance is computed as disagreement across ensemble members
and captures epistemic (model) uncertainty.
If return_samples=True: (mean, variance, all_predictions)
"""
_return = return_samples or return_all
predictions = []
for model in self.models:
model.eval()
output = model(x)
# Convert to probabilities for classification
if normalize_output and output.dim() == 2 and output.size(-1) > 1:
output = F.softmax(output, dim=-1)
predictions.append(output)
predictions = torch.stack(predictions) # (num_models, batch_size, *)
mean = predictions.mean(dim=0)
variance = predictions.var(dim=0)
if _return:
return mean, variance, predictions
return mean, variance
def train_member(
self,
model_idx: int,
train_loader: DataLoader,
optimizer: torch.optim.Optimizer,
criterion: nn.Module,
num_epochs: int = 10,
device: str | None = None,
):
"""
Train a specific ensemble member.
Args:
model_idx: Index of model to train
train_loader: Training data loader
optimizer: Optimizer instance
criterion: Loss function
num_epochs: Number of training epochs
device: Device to train on (default: auto-detect CUDA/CPU)
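        Example:
            A minimal sketch (``train_loader`` is a hypothetical classification
            DataLoader; the optimizer and loss are just one reasonable choice):
            >>> for i in range(ensemble.num_models):
            ...     member = ensemble.models[i]
            ...     opt = torch.optim.Adam(member.parameters(), lr=1e-3)
            ...     ensemble.train_member(i, train_loader, opt, nn.CrossEntropyLoss())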
"""
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = self.models[model_idx].to(device)
model.train()
for epoch in range(num_epochs):
total_loss = 0
for batch_x, batch_y in train_loader:
batch_x, batch_y = batch_x.to(device), batch_y.to(device)
optimizer.zero_grad()
outputs = model(batch_x)
loss = criterion(outputs, batch_y)
loss.backward()
optimizer.step()
total_loss += loss.item()
def diversity(self, x: torch.Tensor) -> torch.Tensor:
"""
Compute ensemble diversity (disagreement).
Args:
x: Input tensor
Returns:
Per-sample diversity score
"""
_, variance, _ = self.predict(x, return_samples=True)
return variance.mean(dim=-1)
class SWAG(BaseBayesianMethod):
"""
Stochastic Weight Averaging - Gaussian (SWAG).
Approximates the posterior over weights using a Gaussian distribution
fitted to the trajectory of SGD. Efficient and scalable Bayesian
inference method.
Reference:
        Maddox et al., "A Simple Baseline for Bayesian Uncertainty
        in Deep Learning" (NeurIPS 2019)
Args:
model: Base neural network
num_samples: Number of samples for prediction (default: 20)
var_clamp: Variance clamping value (default: 1e-6)
Note:
This is a diagonal SWAG implementation. The full SWAG algorithm
also includes a low-rank covariance component, which is not
implemented here for simplicity and memory efficiency.
Example:
>>> model = ResNet18(num_classes=10)
>>> swag = SWAG(model)
>>> # Train normally, then collect SWAG statistics
>>> for epoch in range(start_swag, num_epochs):
... train_epoch(model, ...)
... swag.collect_model(model)
>>> mean, variance = swag.predict(x)
"""
def __init__(
self,
model: nn.Module,
num_samples: int = 20,
var_clamp: float = 1e-6,
):
super().__init__()
self.model = model
self.num_samples = num_samples
self.var_clamp = var_clamp
# Statistics for diagonal SWAG
self.mean = {}
self.sq_mean = {}
self.n_models = 0
# Initialize statistics
for name, param in model.named_parameters():
self.mean[name] = torch.zeros_like(param.data)
self.sq_mean[name] = torch.zeros_like(param.data)
def _apply(self, fn):
"""Override to move SWAG statistics with .to() / .cuda() / etc."""
super()._apply(fn)
self.mean = {name: fn(tensor) for name, tensor in self.mean.items()}
self.sq_mean = {name: fn(tensor) for name, tensor in self.sq_mean.items()}
return self
def collect_model(self, model: nn.Module):
"""
Collect model statistics for SWAG.
Call this method periodically during training (e.g., every epoch
after a warmup period) to collect weight statistics.
Args:
model: Current model to collect statistics from
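        Note:
            Each call updates running moments of the weights,
            ``mean_n = ((n - 1) * mean_{n-1} + w_n) / n`` and
            ``sq_mean_n = ((n - 1) * sq_mean_{n-1} + w_n ** 2) / n``,
            where ``n`` is the number of snapshots collected so far.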
"""
self.n_models += 1
for name, param in model.named_parameters():
            # Update first moment
            self.mean[name] = (
                (self.n_models - 1) / self.n_models * self.mean[name]
                + 1.0 / self.n_models * param.data
            )
            # Update second moment
            self.sq_mean[name] = (
                (self.n_models - 1) / self.n_models * self.sq_mean[name]
                + 1.0 / self.n_models * param.data**2
            )
def sample_parameters(self) -> dict:
"""
Sample parameters from the SWAG posterior.
Returns:
Dictionary of sampled parameters
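        Note:
            Each parameter is drawn from a diagonal Gaussian,
            ``w ~ N(mean, clamp(sq_mean - mean ** 2, min=var_clamp))``.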
"""
sampled_params = {}
for name in self.mean.keys():
# Compute variance
var = torch.clamp(
self.sq_mean[name] - self.mean[name] ** 2, min=self.var_clamp
)
# Sample from Gaussian
std = torch.sqrt(var)
sampled_params[name] = self.mean[name] + torch.randn_like(std) * std
return sampled_params
def state_dict(self) -> dict:
"""
Return state dictionary for serialization.
Returns:
State dictionary containing model state and SWAG statistics.
"""
state = super().state_dict()
state["_swag_mean"] = self.mean
state["_swag_sq_mean"] = self.sq_mean
state["_swag_n_models"] = self.n_models
return state
def load_state_dict(self, state_dict: dict, strict: bool = True):
"""
Load state dictionary.
Args:
state_dict: State dictionary to load.
strict: Whether to strictly enforce key matching.
"""
# Copy to avoid mutating caller's dict
state_dict = dict(state_dict)
# Extract SWAG-specific state
self.mean = state_dict.pop("_swag_mean", {})
self.sq_mean = state_dict.pop("_swag_sq_mean", {})
self.n_models = state_dict.pop("_swag_n_models", 0)
# Load remaining state (model parameters)
super().load_state_dict(state_dict, strict=strict)
@torch.no_grad()
def predict(
self,
x: torch.Tensor,
return_samples: bool = False,
normalize_output: bool = True,
) -> Union[
Tuple[torch.Tensor, torch.Tensor],
Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
]:
"""
SWAG prediction with uncertainty.
Args:
x: Input tensor ``(N, ...)``
return_samples: If True, return all sampled predictions
normalize_output: If True, apply softmax to 2-D multi-column
outputs (i.e. treat them as logits). Set to False when
the model already returns probabilities or when outputs
are not classification logits.
Returns:
Tuple of (mean_prediction, predictive_variance)
If return_samples=True: (mean, variance, samples)
"""
if self.n_models == 0:
from ..exceptions import NotFittedError
raise NotFittedError("No models collected. Call collect_model() first.")
predictions = []
# Save original parameters
original_params = {
name: param.data.clone() for name, param in self.model.named_parameters()
}
try:
# Sample and predict
for _ in range(self.num_samples):
sampled_params = self.sample_parameters()
# Load sampled parameters
for name, param in self.model.named_parameters():
param.data = sampled_params[name]
# Forward pass
output = self.model(x)
if normalize_output and output.dim() == 2 and output.size(-1) > 1:
output = F.softmax(output, dim=-1)
predictions.append(output)
finally:
# Restore original parameters even if forward pass fails
for name, param in self.model.named_parameters():
param.data = original_params[name]
predictions = torch.stack(predictions)
mean = predictions.mean(dim=0)
variance = predictions.var(dim=0)
if return_samples:
return mean, variance, predictions
return mean, variance
class LaplaceApproximation(BaseBayesianMethod):
"""
Laplace Approximation for Bayesian Neural Networks.
Approximates the posterior over weights using a Gaussian centered at
the MAP estimate (trained weights) with covariance derived from the
Hessian of the loss function.
Reference:
MacKay, "A Practical Bayesian Framework for Backpropagation Networks" (1992)
Daxberger et al., "Laplace Redux" (NeurIPS 2021)
Args:
model: Trained neural network (at MAP estimate)
likelihood: Likelihood type ('classification' or 'regression')
prior_precision: Prior precision (inverse variance) (default: 1.0)
Example:
>>> model = ResNet18(num_classes=10)
>>> # Train model to convergence
>>> train_model(model, train_loader)
>>> laplace = LaplaceApproximation(model, likelihood='classification')
>>> laplace.fit(train_loader)
>>> mean, variance = laplace.predict(x)
"""
def __init__(
self,
model: nn.Module,
likelihood: str = "classification",
prior_precision: float = 1.0,
num_samples: int = 20,
):
super().__init__()
self.model = model
self.likelihood = likelihood
self.prior_precision = prior_precision
self.num_samples = num_samples
# Will store the posterior covariance
self.posterior_precision = None
self.mean = None
def _apply(self, fn):
"""Override to move Laplace statistics with .to() / .cuda() / etc."""
super()._apply(fn)
if self.mean is not None:
self.mean = {name: fn(tensor) for name, tensor in self.mean.items()}
if self.posterior_precision is not None:
self.posterior_precision = {
name: fn(tensor) for name, tensor in self.posterior_precision.items()
}
return self
def _compute_hessian_diag(
self,
data_loader: DataLoader,
device: str | None = None,
) -> dict:
"""
Compute diagonal of the empirical Fisher via per-sample gradients.
Each sample's gradient is squared individually and then averaged,
giving an unbiased estimate of the Fisher diagonal.
Args:
data_loader: Data loader for computing Hessian
device: Device to use (default: auto-detect CUDA/CPU)
Returns:
Dictionary of Hessian diagonal per parameter
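        Note:
            Concretely, ``F_ii ≈ (1 / n) * sum_k (d NLL_k / d theta_i) ** 2``,
            where the sum runs over the ``n`` individual examples in
            ``data_loader``.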
"""
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
hessian_diag = {}
# Initialize
for name, param in self.model.named_parameters():
hessian_diag[name] = torch.zeros_like(param.data)
self.model.to(device)
self.model.eval()
n_samples = 0
for batch_x, batch_y in data_loader:
batch_x, batch_y = batch_x.to(device), batch_y.to(device)
# Per-sample gradients for an unbiased Fisher diagonal
for i in range(len(batch_x)):
self.model.zero_grad()
output = self.model(batch_x[i : i + 1])
if self.likelihood == "classification":
loss = F.cross_entropy(output, batch_y[i : i + 1])
else: # regression — Gaussian NLL with unit variance
loss = 0.5 * F.mse_loss(
                        output.view(-1), batch_y[i : i + 1].float().view(-1)
)
loss.backward()
for name, param in self.model.named_parameters():
if param.grad is not None:
hessian_diag[name] += param.grad.data**2
n_samples += 1
# Average over all samples
for name in hessian_diag:
hessian_diag[name] /= n_samples
return hessian_diag
def fit(self, data_loader: DataLoader, device: str | None = None):
"""
Fit Laplace approximation by computing Hessian.
Args:
data_loader: Data loader for computing Hessian
device: Device to use (default: auto-detect CUDA/CPU)
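        Note:
            The fitted posterior is diagonal: for each parameter the precision
            is ``prior_precision + F_ii``, with ``F_ii`` the empirical Fisher
            diagonal from ``_compute_hessian_diag``. ``predict`` then samples
            weights with standard deviation ``posterior_precision ** -0.5``.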
"""
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
# Store MAP estimate (current weights)
self.mean = {
name: param.data.clone() for name, param in self.model.named_parameters()
}
# Compute Hessian diagonal
hessian_diag = self._compute_hessian_diag(data_loader, device)
# Posterior precision = prior precision + Hessian
self.posterior_precision = {
name: self.prior_precision + hessian_diag[name] for name in hessian_diag
}
@torch.no_grad()
def predict(
self,
x: torch.Tensor,
return_samples: bool = False,
normalize_output: bool = True,
) -> Union[
Tuple[torch.Tensor, torch.Tensor],
Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
]:
"""
Laplace prediction with uncertainty.
Args:
x: Input tensor ``(N, ...)``
return_samples: If True, return all sampled predictions
normalize_output: If True, apply softmax to 2-D multi-column
outputs (i.e. treat them as logits). Set to False when
the model already returns probabilities or when outputs
are not classification logits.
Returns:
Tuple of (mean_prediction, predictive_variance)
If return_samples=True: (mean, variance, samples)
"""
        if self.mean is None or self.posterior_precision is None:
from ..exceptions import NotFittedError
raise NotFittedError("Model not fitted. Call fit() first.")
predictions = []
# Save original parameters
original_params = {
name: param.data.clone() for name, param in self.model.named_parameters()
}
try:
# Sample from posterior
for _ in range(self.num_samples):
for name, param in self.model.named_parameters():
# Sample from Gaussian posterior
std = 1.0 / torch.sqrt(self.posterior_precision[name])
param.data = self.mean[name] + torch.randn_like(std) * std
output = self.model(x)
if normalize_output and output.dim() == 2 and output.size(-1) > 1:
output = F.softmax(output, dim=-1)
predictions.append(output)
finally:
# Restore original parameters even if forward pass fails
for name, param in self.model.named_parameters():
param.data = original_params[name]
predictions = torch.stack(predictions)
mean = predictions.mean(dim=0)
variance = predictions.var(dim=0)
if return_samples:
return mean, variance, predictions
return mean, variance
def state_dict(self) -> dict:
"""
Return state dictionary for serialization.
Returns:
State dictionary containing model state and Laplace statistics.
"""
state = super().state_dict()
state["_laplace_mean"] = self.mean
state["_laplace_posterior_precision"] = self.posterior_precision
state["_laplace_likelihood"] = self.likelihood
state["_laplace_prior_precision"] = self.prior_precision
state["_laplace_num_samples"] = self.num_samples
return state
def load_state_dict(self, state_dict: dict, strict: bool = True):
"""
Load state dictionary.
Args:
state_dict: State dictionary to load.
strict: Whether to strictly enforce key matching.
"""
# Copy to avoid mutating caller's dict
state_dict = dict(state_dict)
# Extract Laplace-specific state
self.mean = state_dict.pop("_laplace_mean", None)
self.posterior_precision = state_dict.pop("_laplace_posterior_precision", None)
self.likelihood = state_dict.pop("_laplace_likelihood", "classification")
self.prior_precision = state_dict.pop("_laplace_prior_precision", 1.0)
self.num_samples = state_dict.pop("_laplace_num_samples", 20)
# Load remaining state (model parameters)
super().load_state_dict(state_dict, strict=strict)
class VariationalBayesNN(BaseBayesianMethod):
"""
Variational Bayesian Neural Network (Bayes by Backprop).
Learns a distribution over weights using variational inference.
Each weight has a learned mean and variance.
.. note::
``variational_loss`` currently uses cross-entropy and therefore
only supports **classification** tasks. For regression, provide
a custom training loop with an appropriate likelihood.
Reference:
Blundell et al., "Weight Uncertainty in Neural Networks" (ICML 2015)
Args:
in_features: Input dimension
hidden_sizes: List of hidden layer sizes
out_features: Output dimension
prior_std: Prior standard deviation (default: 1.0)
num_samples: Number of MC samples for prediction (default: 20)
Example:
>>> model = VariationalBayesNN(784, [512, 256], 10)
>>> # Train with variational loss
>>> for x, y in train_loader:
... loss = model.variational_loss(x, y, num_samples=10)
... loss.backward()
>>> mean, variance = model.predict(x)
"""
def __init__(
self,
in_features: int,
hidden_sizes: List[int],
out_features: int,
prior_std: float = 1.0,
num_samples: int = 20,
):
super().__init__()
self.prior_std = prior_std
self.num_samples = num_samples
# Build network with Gaussian weights
layers = []
prev_size = in_features
for hidden_size in hidden_sizes:
layers.append(GaussianLinear(prev_size, hidden_size, prior_std=prior_std))
prev_size = hidden_size
layers.append(GaussianLinear(prev_size, out_features, prior_std=prior_std))
self.layers = nn.ModuleList(layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass with sampled weights."""
for i, layer in enumerate(self.layers):
x = layer(x)
if i < len(self.layers) - 1:
x = F.relu(x)
return x
def kl_divergence(self) -> torch.Tensor:
"""Compute KL divergence between posterior and prior."""
kl = 0
for layer in self.layers:
kl += layer.kl_divergence()
return kl
def variational_loss(
self,
x: torch.Tensor,
y: torch.Tensor,
num_samples: int = 10,
kl_weight: float = 1.0,
) -> torch.Tensor:
"""
Compute variational loss (ELBO).
        Loss = E_q[NLL] + kl_weight * KL[q(w) || p(w)] / batch_size
Args:
x: Input tensor
y: Target labels
num_samples: Number of samples for MC estimate
kl_weight: Weight for KL term
Returns:
Variational loss
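        Note:
            The KL term is divided by the batch size here, so for a dataset of
            ``N`` examples trained with batch size ``B``, setting
            ``kl_weight = B / N`` recovers the standard per-example ELBO
            weighting.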
"""
# Likelihood term (averaged over samples)
nll = 0
for _ in range(num_samples):
outputs = self.forward(x)
nll += F.cross_entropy(outputs, y)
nll /= num_samples
# KL term
kl = self.kl_divergence()
# ELBO
return nll + kl_weight * kl / len(x)
@torch.no_grad()
def predict(
self,
x: torch.Tensor,
return_samples: bool = False,
normalize_output: bool = True,
) -> Union[
Tuple[torch.Tensor, torch.Tensor],
Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
]:
"""
Variational prediction with uncertainty.
Args:
x: Input tensor
return_samples: If True, return all sampled predictions
normalize_output: If True, apply softmax to 2-D multi-column
outputs (i.e. treat them as logits). Set to False when
the model already returns probabilities or when outputs
are not classification logits.
Returns:
Tuple of (mean_prediction, predictive_variance)
If return_samples=True: (mean, variance, samples)
"""
predictions = []
for _ in range(self.num_samples):
output = self.forward(x)
if normalize_output and output.dim() == 2 and output.size(-1) > 1:
output = F.softmax(output, dim=-1)
predictions.append(output)
predictions = torch.stack(predictions)
mean = predictions.mean(dim=0)
variance = predictions.var(dim=0)
if return_samples:
return mean, variance, predictions
return mean, variance
class GaussianLinear(nn.Module):
"""
Linear layer with Gaussian weights for variational inference.
    Each weight has a learnable mean (``mu``) and a standard deviation
    parameterized as ``softplus(rho)``.
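    Example:
        A minimal sketch (shapes are illustrative):
        >>> layer = GaussianLinear(16, 4)
        >>> out = layer(torch.randn(8, 16))  # weights are resampled each call
        >>> kl = layer.kl_divergence()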
"""
def __init__(
self,
in_features: int,
out_features: int,
prior_std: float = 1.0,
):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.prior_std = prior_std
        # Initialize rho so that softplus(rho) ≈ prior_std
        # (inverse_softplus(x) = log(exp(x) - 1)); with the posterior std equal
        # to the prior std, the variance terms of the initial KL cancel.
        init_rho = torch.log(torch.tensor(prior_std).exp() - 1).item()
        # Variational parameters: mu is drawn from the prior N(0, prior_std^2),
        # so the posterior starts close to the prior (the residual KL per
        # weight is mu^2 / (2 * prior_std^2)).
self.weight_mu = nn.Parameter(
torch.randn(out_features, in_features) * prior_std
)
self.weight_rho = nn.Parameter(
torch.full((out_features, in_features), init_rho)
)
self.bias_mu = nn.Parameter(torch.zeros(out_features))
self.bias_rho = nn.Parameter(torch.full((out_features,), init_rho))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass with sampled weights."""
# Sample weights (softplus is numerically stable for large rho)
weight_std = F.softplus(self.weight_rho)
weight = self.weight_mu + weight_std * torch.randn_like(weight_std)
# Sample bias
bias_std = F.softplus(self.bias_rho)
bias = self.bias_mu + bias_std * torch.randn_like(bias_std)
return F.linear(x, weight, bias)
def kl_divergence(self) -> torch.Tensor:
"""KL divergence between posterior and prior."""
# Weight KL
weight_std = F.softplus(self.weight_rho)
weight_kl = (
torch.log(self.prior_std / weight_std)
+ (weight_std**2 + self.weight_mu**2) / (2 * self.prior_std**2)
- 0.5
).sum()
# Bias KL
bias_std = F.softplus(self.bias_rho)
bias_kl = (
torch.log(self.prior_std / bias_std)
+ (bias_std**2 + self.bias_mu**2) / (2 * self.prior_std**2)
- 0.5
).sum()
return weight_kl + bias_kl
__all__ = [
"BaseBayesianMethod",
"MCDropout",
"DeepEnsemble",
"SWAG",
"LaplaceApproximation",
"VariationalBayesNN",
"GaussianLinear",
]