Source code for incerto.active.acquisition

"""
Acquisition functions for active learning.

Acquisition functions score unlabeled samples based on their informativeness,
allowing selection of the most valuable samples for labeling.
"""

from __future__ import annotations
from abc import ABC, abstractmethod
import torch
import torch.nn.functional as F


class BaseAcquisition(ABC):
    """Base class for acquisition functions."""

    @abstractmethod
    def score(
        self,
        model: torch.nn.Module,
        x: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """
        Compute acquisition scores for unlabeled samples.

        Args:
            model: Trained model
            x: Unlabeled samples ``(N, ...)``
            **kwargs: Additional arguments

        Returns:
            Acquisition scores ``(N,)``, higher = more informative
        """
        pass

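
# --- Illustrative sketch (not part of the library): a custom acquisition ---
# Subclassing BaseAcquisition only requires implementing `score`. The
# hypothetical class below scores samples by the negative maximum logit,
# so samples with a low peak logit are treated as more informative.
class _ExampleMaxLogitAcquisition(BaseAcquisition):
    """Toy acquisition: negative maximum logit as an uncertainty proxy."""

    @torch.no_grad()
    def score(
        self,
        model: torch.nn.Module,
        x: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        was_training = model.training
        try:
            model.eval()
            logits = model(x)
            # Lower peak logit -> higher score -> more likely to be queried.
            return -logits.max(dim=-1).values
        finally:
            model.train(was_training)
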
class RandomAcquisition(BaseAcquisition):
    """
    Random sampling (baseline).

    Samples are selected uniformly at random.
    """

    def score(
        self,
        model: torch.nn.Module,
        x: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Return random scores."""
        return torch.rand(len(x), device=x.device)

class EntropyAcquisition(BaseAcquisition):
    """
    Entropy-based acquisition.

    Select samples with highest predictive entropy (most uncertain).

    Reference:
        Shannon, "A Mathematical Theory of Communication" (1948)
    """

    @torch.no_grad()
    def score(
        self,
        model: torch.nn.Module,
        x: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Compute entropy scores."""
        was_training = model.training
        try:
            model.eval()
            logits = model(x)
            probs = F.softmax(logits, dim=-1)

            # Entropy: -∑ p log p
            entropy = -(probs * torch.log(probs + 1e-10)).sum(dim=-1)
            return entropy
        finally:
            model.train(was_training)

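
# --- Illustrative usage sketch (hypothetical helper, not library API) ---
# Pool-based active learning typically queries the top-k scoring unlabeled
# samples. `model` and `pool` are assumed to be a trained classifier and a
# tensor of unlabeled inputs, respectively.
def _select_top_k(
    acquisition: BaseAcquisition,
    model: torch.nn.Module,
    pool: torch.Tensor,
    k: int = 10,
) -> torch.Tensor:
    """Return indices of the k highest-scoring pool samples."""
    scores = acquisition.score(model, pool)  # shape (N,)
    return torch.topk(scores, min(k, len(pool))).indices


# e.g. query_idx = _select_top_k(EntropyAcquisition(), model, pool, k=32)
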
class LeastConfidenceAcquisition(BaseAcquisition):
    """
    Least confidence acquisition.

    Select samples where the model is least confident about its prediction.

    Reference:
        Lewis & Gale, "A Sequential Algorithm for Training Text Classifiers" (1994)
    """

    @torch.no_grad()
    def score(
        self,
        model: torch.nn.Module,
        x: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Compute least confidence scores."""
        was_training = model.training
        try:
            model.eval()
            logits = model(x)
            probs = F.softmax(logits, dim=-1)

            # Least confidence: 1 - max(p)
            max_probs, _ = probs.max(dim=-1)
            return 1.0 - max_probs
        finally:
            model.train(was_training)

class MarginAcquisition(BaseAcquisition):
    """
    Margin sampling acquisition.

    Select samples with the smallest margin between the top-2 predictions.

    Reference:
        Scheffer et al., "Active Hidden Markov Models for Information Extraction" (2001)
    """

    @torch.no_grad()
    def score(
        self,
        model: torch.nn.Module,
        x: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Compute margin scores."""
        was_training = model.training
        try:
            model.eval()
            logits = model(x)
            probs = F.softmax(logits, dim=-1)

            # Guard: need at least 2 classes for margin
            if probs.size(-1) < 2:
                # Single class: no margin, return zeros (no uncertainty)
                return torch.zeros(len(x), device=x.device)

            # Sort probabilities
            sorted_probs, _ = torch.sort(probs, descending=True, dim=-1)

            # Margin: difference between top-2
            margin = sorted_probs[:, 0] - sorted_probs[:, 1]

            # Return negative margin (higher = smaller margin = more uncertain)
            return -margin
        finally:
            model.train(was_training)

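
# --- Illustrative sketch (hypothetical, for intuition only) ---
# Margin and least confidence can rank samples differently. With class
# probabilities [0.50, 0.45, 0.05] the margin score is -(0.50 - 0.45) = -0.05,
# while least confidence gives 1 - 0.50 = 0.50. A tiny fixed-output model
# makes this easy to verify:
class _ConstantLogitModel(torch.nn.Module):
    """Toy model that ignores its input and returns fixed logits."""

    def __init__(self, probs: torch.Tensor):
        super().__init__()
        # log-probabilities are valid logits: softmax(log p) recovers p
        self.register_buffer("logits", torch.log(probs))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.logits.expand(len(x), -1)


# probs = torch.tensor([[0.50, 0.45, 0.05]])
# MarginAcquisition().score(_ConstantLogitModel(probs), torch.zeros(1, 1))
# -> tensor([-0.0500])
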
class BALDAcquisition(BaseAcquisition):
    """
    Bayesian Active Learning by Disagreement (BALD).

    Selects samples that maximize the mutual information between predictions
    and model parameters:

        I[y; θ | x] = H[y | x] - E_θ[H[y | x, θ]]

    Reference:
        Houlsby et al., "Bayesian Active Learning for Classification and
        Preference Learning" (2011)
    """
    def __init__(self, num_samples: int = 20):
        """
        Initialize BALD acquisition.

        Args:
            num_samples: Number of MC samples for Bayesian inference
        """
        self.num_samples = num_samples

    @torch.no_grad()
    def score(
        self,
        model: torch.nn.Module,
        x: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Compute BALD scores."""
        was_training = model.training
        try:
            # Enable dropout for MC sampling
            model.train()

            # Collect predictions
            predictions = []
            for _ in range(self.num_samples):
                logits = model(x)
                probs = F.softmax(logits, dim=-1)
                predictions.append(probs)
            predictions = torch.stack(predictions)  # (num_samples, batch_size, num_classes)

            # Expected entropy: E_θ[H[y|x,θ]]
            expected_entropy = (
                -(predictions * torch.log(predictions + 1e-10)).sum(dim=-1).mean(dim=0)
            )

            # Entropy of mean: H[E_θ[y|x,θ]]
            mean_probs = predictions.mean(dim=0)
            entropy_of_mean = -(mean_probs * torch.log(mean_probs + 1e-10)).sum(dim=-1)

            # Mutual information (BALD score)
            bald = entropy_of_mean - expected_entropy
            return bald
        finally:
            model.train(was_training)

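
# --- Illustrative usage sketch (hypothetical helper, not library API) ---
# BALD relies on stochastic forward passes: the model must contain layers
# such as torch.nn.Dropout that remain active in train mode. A purely
# deterministic network yields (near-)zero BALD scores for every sample.
def _example_bald_query(
    model: torch.nn.Module, pool: torch.Tensor, k: int = 16
) -> torch.Tensor:
    """Return indices of the k pool samples with the highest BALD score."""
    scores = BALDAcquisition(num_samples=50).score(model, pool)
    return torch.topk(scores, min(k, len(pool))).indices
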
class VarianceRatioAcquisition(BaseAcquisition):
    """
    Variance ratio (variation ratio) acquisition.

    Measures disagreement between stochastic forward passes as:

        1 - (mode_count / num_samples)

    Reference:
        Freeman, "Elementary Applied Statistics" (1965)
    """
    def __init__(self, num_samples: int = 20):
        """
        Initialize variance ratio acquisition.

        Args:
            num_samples: Number of samples for estimation
        """
        self.num_samples = num_samples

    @torch.no_grad()
    def score(
        self,
        model: torch.nn.Module,
        x: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Compute variance ratio scores."""
        was_training = model.training
        try:
            model.train()

            # Collect predictions
            predictions = []
            for _ in range(self.num_samples):
                logits = model(x)
                pred_labels = logits.argmax(dim=-1)
                predictions.append(pred_labels)
            predictions = torch.stack(predictions)  # (num_samples, batch_size)

            # Compute mode frequency (vectorized)
            mode_values = torch.mode(predictions, dim=0).values  # (batch_size,)
            mode_freq = (
                (predictions == mode_values.unsqueeze(0)).float().sum(dim=0)
                / self.num_samples
            )

            return 1.0 - mode_freq
        finally:
            model.train(was_training)

class MeanSTDAcquisition(BaseAcquisition):
    """
    Mean standard deviation acquisition.

    Uses the average standard deviation across output probabilities as a
    measure of uncertainty.
    """

    def __init__(self, num_samples: int = 20):
        """
        Initialize mean STD acquisition.

        Args:
            num_samples: Number of samples for estimation
        """
        self.num_samples = num_samples

    @torch.no_grad()
    def score(
        self,
        model: torch.nn.Module,
        x: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Compute mean STD scores."""
        was_training = model.training
        try:
            model.train()

            # Collect predictions
            predictions = []
            for _ in range(self.num_samples):
                logits = model(x)
                probs = F.softmax(logits, dim=-1)
                predictions.append(probs)
            predictions = torch.stack(predictions)  # (num_samples, batch_size, num_classes)

            # Compute standard deviation
            std = predictions.std(dim=0)

            # Average over classes
            mean_std = std.mean(dim=-1)
            return mean_std
        finally:
            model.train(was_training)

class BatchBALDAcquisition(BaseAcquisition):
    """
    Approximate BatchBALD via individual BALD scores.

    Full BatchBALD (Kirsch et al., NeurIPS 2019) greedily selects batches that
    jointly maximise information gain by computing joint entropies. This
    implementation returns per-sample BALD scores as a tractable
    approximation; for true batch-aware selection, pair with a diversity-aware
    strategy (e.g., ``DiversitySampling``).

    Reference:
        Kirsch et al., "BatchBALD: Efficient and Diverse Batch Acquisition
        for Deep Bayesian Active Learning" (NeurIPS 2019)
    """

    def __init__(self, num_samples: int = 20):
        """
        Initialize BatchBALD acquisition.

        Args:
            num_samples: Number of MC samples
        """
        self.num_samples = num_samples

    @torch.no_grad()
    def score(
        self,
        model: torch.nn.Module,
        x: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """
        Compute per-sample BALD scores as a BatchBALD approximation.

        Note:
            This is a simplified version that returns individual BALD scores.
            Full BatchBALD requires joint entropy computation, which is
            computationally expensive. For batch-aware diversity, pair with a
            diversity-aware strategy like ``DiversitySampling``.

        Args:
            model: Model
            x: Unlabeled samples

        Returns:
            Per-sample BALD scores
        """
        was_training = model.training
        try:
            model.train()

            # Collect predictions
            predictions = []
            for _ in range(self.num_samples):
                logits = model(x)
                probs = F.softmax(logits, dim=-1)
                predictions.append(probs)
            predictions = torch.stack(predictions)

            # Compute individual BALD scores
            expected_entropy = (
                -(predictions * torch.log(predictions + 1e-10)).sum(dim=-1).mean(dim=0)
            )
            mean_probs = predictions.mean(dim=0)
            entropy_of_mean = -(mean_probs * torch.log(mean_probs + 1e-10)).sum(dim=-1)
            bald_scores = entropy_of_mean - expected_entropy

            return bald_scores
        finally:
            model.train(was_training)

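
# --- Illustrative sketch of one pool-based active-learning round ---
# (hypothetical glue code, not part of the library). `label_fn` stands in for
# an oracle/annotator and `train_fn` for whatever training loop is used; both
# are assumptions made for this example.
def _example_active_learning_round(
    model: torch.nn.Module,
    pool: torch.Tensor,
    acquisition: BaseAcquisition,
    label_fn,
    train_fn,
    k: int = 32,
) -> torch.Tensor:
    """Query the k most informative samples, label them, and retrain."""
    scores = acquisition.score(model, pool)
    query_idx = torch.topk(scores, min(k, len(pool))).indices
    new_x = pool[query_idx]
    new_y = label_fn(new_x)         # obtain labels from the oracle
    train_fn(model, new_x, new_y)   # update the model on the newly labeled data
    return query_idx
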
__all__ = [
    "BaseAcquisition",
    "RandomAcquisition",
    "EntropyAcquisition",
    "LeastConfidenceAcquisition",
    "MarginAcquisition",
    "BALDAcquisition",
    "VarianceRatioAcquisition",
    "MeanSTDAcquisition",
    "BatchBALDAcquisition",
]