# Source code for incerto.active.strategies

"""
Query strategies for active learning.

Strategies for selecting batches of samples to label, combining
acquisition functions with diversity and batch considerations.
"""

from __future__ import annotations
from typing import List, Optional
import torch
import torch.nn.functional as F
from .acquisition import BaseAcquisition


class UncertaintySampling:
    """
    Pure uncertainty-based query strategy.

    Ranks the unlabeled pool by its acquisition score and returns the
    indices of the top ``batch_size`` samples.

    Args:
        acquisition_fn: Acquisition function to use
        batch_size: Number of samples to select per query

    Example:
        >>> acquisition = BALDAcquisition(num_samples=20)
        >>> strategy = UncertaintySampling(acquisition, batch_size=100)
        >>> indices = strategy.query(model, unlabeled_data)
    """

    def __init__(
        self,
        acquisition_fn: BaseAcquisition,
        batch_size: int = 100,
    ):
        self.acquisition_fn = acquisition_fn
        self.batch_size = batch_size

    def query(
        self,
        model: torch.nn.Module,
        x_unlabeled: torch.Tensor,
    ) -> torch.Tensor:
        """
        Return the indices of the most uncertain samples.

        Args:
            model: Trained model
            x_unlabeled: Unlabeled data ``(N, ...)``

        Returns:
            Indices of selected samples (batch_size,)
        """
        pool_scores = self.acquisition_fn.score(model, x_unlabeled)
        # Never request more samples than the pool actually contains.
        k = min(self.batch_size, len(pool_scores))
        return pool_scores.topk(k).indices
class DiversitySampling:
    """
    Uncertainty sampling with an explicit diversity term.

    Greedily assembles a batch: each step blends the (min-max rescaled)
    acquisition score with the distance to the nearest already-selected
    sample, so the batch avoids near-duplicate queries.

    Reference:
        Brinker, "Incorporating Diversity in Active Learning" (ICML 2003)

    Args:
        acquisition_fn: Acquisition function
        batch_size: Number of samples to select
        diversity_weight: Weight for diversity term (0-1)
    """

    def __init__(
        self,
        acquisition_fn: BaseAcquisition,
        batch_size: int = 100,
        diversity_weight: float = 0.5,
    ):
        self.acquisition_fn = acquisition_fn
        self.batch_size = batch_size
        self.diversity_weight = diversity_weight

    @staticmethod
    def _rescale(values: torch.Tensor) -> torch.Tensor:
        """Min-max normalize to [0, 1]; epsilon guards a constant vector."""
        lo, hi = values.min(), values.max()
        return (values - lo) / (hi - lo + 1e-10)

    @staticmethod
    def _embed(model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
        """Run *model* in eval mode and flatten its outputs to 2-D features."""
        was_training = model.training
        try:
            with torch.no_grad():
                model.eval()
                out = model(x)
            return out.flatten(1) if out.dim() > 2 else out
        finally:
            # Restore the caller's train/eval state regardless of errors.
            model.train(was_training)

    def query(
        self,
        model: torch.nn.Module,
        x_unlabeled: torch.Tensor,
        features: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Select a batch balancing uncertainty against diversity.

        Args:
            model: Trained model
            x_unlabeled: Unlabeled data
            features: Optional precomputed features for diversity computation

        Returns:
            Indices of selected samples
        """
        scores = self.acquisition_fn.score(model, x_unlabeled)
        if features is None:
            # Fall back to the model's raw outputs as the feature space.
            features = self._embed(model, x_unlabeled)
        scores = self._rescale(scores)

        n = len(x_unlabeled)
        weight = self.diversity_weight
        chosen = []
        pool = torch.arange(n)
        for _ in range(min(self.batch_size, n)):
            if not chosen:
                # Seed the batch with the single most uncertain sample.
                pick = int(scores.argmax())
            else:
                # Distance from each candidate to its nearest chosen sample.
                gaps = (
                    torch.cdist(features[pool], features[chosen])
                    .min(dim=1)
                    .values
                )
                gaps = self._rescale(gaps)
                blended = (1 - weight) * scores[pool] + weight * gaps
                pick = int(pool[int(blended.argmax())])
            chosen.append(pick)
            # Shrink the candidate pool to everything not yet chosen.
            taken = set(chosen)
            pool = torch.tensor(
                [i for i in range(n) if i not in taken],
                device=x_unlabeled.device,
            )
        return torch.tensor(chosen, device=x_unlabeled.device)
class CoreSetSelection:
    """
    Core-Set selection for active learning.

    Selects samples that best represent the overall data distribution
    using the greedy k-center algorithm: repeatedly pick the point
    farthest from the current set of centers.

    Reference:
        Sener & Savarese, "Active Learning for Convolutional Neural
        Networks: A Core-Set Approach" (ICLR 2018)

    Args:
        batch_size: Number of samples to select
    """

    def __init__(self, batch_size: int = 100):
        # Number of samples returned by each call to ``query``.
        self.batch_size = batch_size

    def query(
        self,
        model: torch.nn.Module,
        x_unlabeled: torch.Tensor,
        x_labeled: Optional[torch.Tensor] = None,
        features_unlabeled: Optional[torch.Tensor] = None,
        features_labeled: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Select a core-set using greedy k-center.

        Args:
            model: Model for feature extraction
            x_unlabeled: Unlabeled data
            x_labeled: Labeled data (optional)
            features_unlabeled: Precomputed features for unlabeled data
            features_labeled: Precomputed features for labeled data

        Returns:
            Indices of selected samples (each index selected at most once)
        """
        # Extract features with the model in eval mode; the original
        # train/eval state is restored even if the forward pass raises.
        was_training = model.training
        try:
            if features_unlabeled is None:
                with torch.no_grad():
                    model.eval()
                    features_unlabeled = model(x_unlabeled)
                    if features_unlabeled.dim() > 2:
                        features_unlabeled = features_unlabeled.flatten(1)

            if x_labeled is not None and features_labeled is None:
                with torch.no_grad():
                    model.eval()
                    features_labeled = model(x_labeled)
                    if features_labeled.dim() > 2:
                        features_labeled = features_labeled.flatten(1)
        finally:
            model.train(was_training)

        # Initial centers: the labeled pool if available, else empty.
        if features_labeled is not None:
            centers = features_labeled
        else:
            centers = torch.empty(
                0, features_unlabeled.size(1), device=features_unlabeled.device
            )

        selected = []
        # BUGFIX: track already-selected points. Previously, when duplicate
        # feature vectors made every remaining min-distance zero, argmax
        # returned index 0 on every iteration and the same sample was
        # queried repeatedly. Masking selected indices to -inf guarantees
        # each returned index is distinct.
        selected_mask = torch.zeros(
            len(features_unlabeled),
            dtype=torch.bool,
            device=features_unlabeled.device,
        )
        for _ in range(min(self.batch_size, len(x_unlabeled))):
            if len(centers) == 0:
                # No centers yet (no labeled data): start from a random point.
                idx = torch.randint(len(features_unlabeled), (1,)).item()
            else:
                # Distance from every candidate to its nearest center.
                dists = torch.cdist(features_unlabeled, centers).min(dim=1).values
                dists = dists.masked_fill(selected_mask, float("-inf"))
                # Farthest-first: take the point worst covered by the centers.
                idx = dists.argmax().item()

            selected.append(idx)
            selected_mask[idx] = True
            # The chosen point becomes a center for subsequent iterations.
            centers = torch.cat([centers, features_unlabeled[idx : idx + 1]], dim=0)

        return torch.tensor(selected, device=x_unlabeled.device)
class BadgeSampling:
    """
    BADGE (Batch Active learning by Diverse Gradient Embeddings).

    Combines gradient-based embeddings with k-MEANS++ for diverse
    batch selection.

    Reference:
        Ash et al., "Deep Batch Active Learning by Diverse, Uncertain
        Gradient Lower Bounds" (ICLR 2020)

    Args:
        batch_size: Number of samples to select
    """

    def __init__(self, batch_size: int = 100):
        # Number of samples returned by each call to ``query``.
        self.batch_size = batch_size

    @staticmethod
    def _find_last_linear(model: torch.nn.Module) -> Optional[torch.nn.Linear]:
        """Find the last nn.Linear layer in the model."""
        # ``modules()`` yields submodules in registration order; the final
        # Linear encountered is treated as the classification head.
        # NOTE(review): assumes registration order matches forward order —
        # confirm for models with non-sequential layouts.
        last_linear = None
        for module in model.modules():
            if isinstance(module, torch.nn.Linear):
                last_linear = module
        return last_linear

    def _compute_gradient_embeddings(
        self,
        model: torch.nn.Module,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute gradient embeddings per the BADGE paper.

        The gradient embedding for sample x is the gradient of the
        cross-entropy loss (with hallucinated label y_hat = argmax)
        w.r.t. the last linear layer's weight, which equals
        ``(p - e_{y_hat}) outer h`` where p is the softmax output,
        e_{y_hat} is one-hot of the predicted class, and h is the
        penultimate-layer features.

        Returns:
            Tensor of shape ``(len(x), C * D)`` — one flattened
            gradient embedding per sample.

        Raises:
            ValueError: If the model contains no ``nn.Linear`` layer.
        """
        was_training = model.training
        model.eval()

        last_linear = self._find_last_linear(model)
        if last_linear is None:
            # Restore the caller's train/eval state before bailing out.
            model.train(was_training)
            raise ValueError(
                "BadgeSampling requires a model with at least one nn.Linear layer"
            )

        # Hook to capture the input to the last linear layer, i.e. the
        # penultimate-layer features h needed for the outer product.
        captured_features = []

        def hook_fn(module, input, output):
            captured_features.append(input[0].detach())

        handle = last_linear.register_forward_hook(hook_fn)

        try:
            embeddings = []
            # no_grad is safe: the gradient has the closed form above, so
            # no autograd backward pass is needed. Samples are processed
            # one at a time (one forward pass each).
            with torch.no_grad():
                for i in range(len(x)):
                    captured_features.clear()
                    output = model(x[i : i + 1])
                    probs = F.softmax(output, dim=-1)  # (1, C)
                    h = captured_features[0]  # (1, D) penultimate features
                    y_hat = probs.argmax(dim=-1)  # (1,)

                    # One-hot of predicted class
                    e_y = torch.zeros_like(probs)
                    e_y[0, y_hat] = 1.0

                    # Gradient embedding: (p - e_y) outer h, flattened
                    diff = (probs - e_y).squeeze(0)  # (C,)
                    h_flat = h.squeeze(0)  # (D,)
                    embedding = torch.outer(diff, h_flat).flatten()  # (C*D,)
                    embeddings.append(embedding)

            return torch.stack(embeddings)
        finally:
            # Always detach the hook and restore the train/eval state.
            handle.remove()
            model.train(was_training)

    def query(
        self,
        model: torch.nn.Module,
        x_unlabeled: torch.Tensor,
    ) -> torch.Tensor:
        """
        Select batch using BADGE.

        Args:
            model: Model
            x_unlabeled: Unlabeled data

        Returns:
            Indices of selected samples
        """
        # Compute gradient embeddings
        embeddings = self._compute_gradient_embeddings(model, x_unlabeled)

        # k-MEANS++ initialization for diversity
        selected = []
        available = torch.arange(len(embeddings))

        # First point: random
        idx = torch.randint(len(embeddings), (1,)).item()
        selected.append(idx)

        # Remaining points: sampled proportionally to squared distance from
        # the nearest selected point (k-means++ seeding).
        for _ in range(min(self.batch_size - 1, len(embeddings) - 1)):
            # Compute distances to nearest selected point. On the first
            # iteration ``available`` still contains the seed point, but its
            # distance (hence probability) is zero, so it cannot be re-drawn.
            dists = (
                torch.cdist(embeddings[available], embeddings[selected])
                .min(dim=1)
                .values
            )

            # Sample proportionally to squared distance
            # NOTE(review): if every remaining embedding coincides with a
            # selected one, the probabilities are 0/0 — verify upstream that
            # embeddings are not all duplicates.
            probs = (dists**2) / (dists**2).sum()
            idx = available[torch.multinomial(probs, 1).item()].item()
            selected.append(idx)

            # Update available (use set for O(1) lookup)
            selected_set = set(selected)
            available = torch.tensor(
                [i for i in range(len(embeddings)) if i not in selected_set],
                device=x_unlabeled.device,
            )

        return torch.tensor(selected, device=x_unlabeled.device)
class QueryByCommittee:
    """
    Query by Committee (QBC).

    Uses disagreement among an ensemble of models to select
    informative samples.

    Reference:
        Seung et al., "Query by Committee" (COLT 1992)

    Args:
        models: List of committee members
        batch_size: Number of samples to select
        disagreement: Disagreement measure ('vote_entropy' or 'kl')
    """

    def __init__(
        self,
        models: List[torch.nn.Module],
        batch_size: int = 100,
        disagreement: str = "vote_entropy",
    ):
        self.models = models
        self.batch_size = batch_size
        self.disagreement = disagreement

    @torch.no_grad()
    def query(
        self,
        model: Optional[torch.nn.Module] = None,
        x_unlabeled: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Query using committee disagreement.

        Args:
            model: Unused (committee members are provided at init).
                Accepted for interface compatibility with other strategies.
            x_unlabeled: Unlabeled data

        Returns:
            Indices of selected samples

        Raises:
            ValueError: If ``x_unlabeled`` is missing or ``disagreement``
                is not one of ``'vote_entropy'`` / ``'kl'``.
        """
        if x_unlabeled is None:
            raise ValueError("x_unlabeled is required")

        # Collect softmax predictions from every committee member, restoring
        # each member's train/eval state afterwards (even on error).
        training_states = [m.training for m in self.models]
        predictions = []
        try:
            for member in self.models:
                member.eval()
                logits = member(x_unlabeled)
                predictions.append(F.softmax(logits, dim=-1))
        finally:
            for member, was_training in zip(self.models, training_states):
                member.train(was_training)

        predictions = torch.stack(predictions)  # (num_models, N, num_classes)

        if self.disagreement == "vote_entropy":
            # Vote entropy: entropy of the committee's hard-vote distribution.
            votes = predictions.argmax(dim=-1)  # (num_models, N)
            num_classes = predictions.size(-1)
            # PERF: vectorized tally — one_hot gives (num_models, N, C);
            # summing over the model axis replaces the previous
            # O(num_models * N) nested Python loop with a single C-level op.
            vote_counts = F.one_hot(votes, num_classes).sum(dim=0).float()
            vote_probs = vote_counts / len(self.models)
            scores = -(vote_probs * torch.log(vote_probs + 1e-10)).sum(dim=-1)

        elif self.disagreement == "kl":
            # Average KL divergence of each member from the mean prediction;
            # broadcasting (num_models, N, C) against (N, C) replaces the
            # per-model Python loop.
            mean_probs = predictions.mean(dim=0)
            kl = (
                predictions
                * (
                    torch.log(predictions + 1e-10)
                    - torch.log(mean_probs + 1e-10)
                )
            ).sum(dim=-1)  # (num_models, N)
            scores = kl.mean(dim=0)

        else:
            raise ValueError(f"Unknown disagreement measure: {self.disagreement}")

        # Select the top-k most-disputed samples (k capped at pool size).
        _, indices = torch.topk(scores, k=min(self.batch_size, len(scores)))
        return indices
# Public names exported by this module (``from ... import *``).
__all__ = [
    "UncertaintySampling",
    "DiversitySampling",
    "CoreSetSelection",
    "BadgeSampling",
    "QueryByCommittee",
]