Source code for incerto.active.strategies

"""
Query strategies for active learning.

Strategies for selecting batches of samples to label, combining
acquisition functions with diversity and batch considerations.
"""

from __future__ import annotations
from typing import List, Optional, Callable
import torch
import torch.nn.functional as F
import numpy as np
from .acquisition import BaseAcquisition, EntropyAcquisition


class UncertaintySampling:
    """
    Uncertainty-based sampling strategy.

    Selects samples with the highest acquisition scores.

    Args:
        acquisition_fn: Acquisition function to use
        batch_size: Number of samples to select per query

    Example:
        >>> acquisition = BALDAcquisition(num_samples=20)
        >>> strategy = UncertaintySampling(acquisition, batch_size=100)
        >>> indices = strategy.query(model, unlabeled_data)
    """
    def __init__(
        self,
        acquisition_fn: BaseAcquisition,
        batch_size: int = 100,
    ):
        self.acquisition_fn = acquisition_fn
        self.batch_size = batch_size
    def query(
        self,
        model: torch.nn.Module,
        x_unlabeled: torch.Tensor,
    ) -> torch.Tensor:
        """
        Query samples based on uncertainty.

        Args:
            model: Trained model
            x_unlabeled: Unlabeled data (N, *)

        Returns:
            Indices of selected samples (batch_size,)
        """
        # Compute acquisition scores
        scores = self.acquisition_fn.score(model, x_unlabeled)

        # Select top-k samples
        _, indices = torch.topk(scores, k=min(self.batch_size, len(scores)))

        return indices

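# Illustrative usage sketch, not part of the library API. It wires
# UncertaintySampling to the EntropyAcquisition imported above and a toy
# linear classifier; the zero-argument EntropyAcquisition() constructor is an
# assumption, only the score(model, x) call is taken from the code above.
def _demo_uncertainty_sampling() -> torch.Tensor:
    """Hypothetical usage sketch for UncertaintySampling."""
    model = torch.nn.Linear(16, 4)   # toy 4-class classifier
    x_pool = torch.randn(200, 16)    # unlabeled pool of 200 samples
    strategy = UncertaintySampling(EntropyAcquisition(), batch_size=10)
    # Indices of the 10 samples with the highest acquisition scores.
    return strategy.query(model, x_pool)
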
class DiversitySampling:
    """
    Diversity-based sampling with uncertainty.

    Balances uncertainty with diversity to avoid selecting redundant samples.

    Reference:
        Brinker, "Incorporating Diversity in Active Learning with Support
        Vector Machines" (ICML 2003)

    Args:
        acquisition_fn: Acquisition function
        batch_size: Number of samples to select
        diversity_weight: Weight for diversity term (0-1)
    """
    def __init__(
        self,
        acquisition_fn: BaseAcquisition,
        batch_size: int = 100,
        diversity_weight: float = 0.5,
    ):
        self.acquisition_fn = acquisition_fn
        self.batch_size = batch_size
        self.diversity_weight = diversity_weight
    def query(
        self,
        model: torch.nn.Module,
        x_unlabeled: torch.Tensor,
        features: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Query samples balancing uncertainty and diversity.

        Args:
            model: Trained model
            x_unlabeled: Unlabeled data
            features: Optional precomputed features for diversity computation

        Returns:
            Indices of selected samples
        """
        # Compute uncertainty scores
        uncertainty_scores = self.acquisition_fn.score(model, x_unlabeled)

        # Extract features if not provided
        if features is None:
            with torch.no_grad():
                model.eval()
                # Use model embeddings as features
                features = model(x_unlabeled)
                if features.dim() > 2:
                    features = features.flatten(1)

        # Normalize scores
        uncertainty_scores = (uncertainty_scores - uncertainty_scores.min()) / (
            uncertainty_scores.max() - uncertainty_scores.min() + 1e-10
        )

        # Greedy selection with diversity
        selected = []
        available = torch.arange(len(x_unlabeled))

        for _ in range(min(self.batch_size, len(x_unlabeled))):
            if len(selected) == 0:
                # First sample: select most uncertain
                idx = uncertainty_scores.argmax()
                selected.append(idx.item())
            else:
                # Compute diversity scores
                selected_features = features[selected]
                diversity_scores = (
                    torch.cdist(features[available], selected_features)
                    .min(dim=1)
                    .values
                )

                # Normalize diversity
                diversity_scores = (diversity_scores - diversity_scores.min()) / (
                    diversity_scores.max() - diversity_scores.min() + 1e-10
                )

                # Combined score
                combined = (1 - self.diversity_weight) * uncertainty_scores[
                    available
                ] + self.diversity_weight * diversity_scores

                # Select best
                best_idx = combined.argmax()
                selected.append(available[best_idx].item())

            # Update available indices
            available = torch.tensor(
                [i for i in range(len(x_unlabeled)) if i not in selected],
                device=x_unlabeled.device,
            )

        return torch.tensor(selected, device=x_unlabeled.device)

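# Illustrative usage sketch, not part of the library API. Passing the raw
# inputs as precomputed features keeps the example self-contained; the
# zero-argument EntropyAcquisition() constructor is an assumption.
def _demo_diversity_sampling() -> torch.Tensor:
    """Hypothetical usage sketch for DiversitySampling."""
    model = torch.nn.Linear(16, 4)
    x_pool = torch.randn(200, 16)
    strategy = DiversitySampling(
        EntropyAcquisition(),
        batch_size=10,
        diversity_weight=0.5,  # equal weight on uncertainty and diversity
    )
    # Supplying features= skips the model-based feature extraction.
    return strategy.query(model, x_pool, features=x_pool)
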
class CoreSetSelection:
    """
    Core-Set selection for active learning.

    Selects samples that best represent the overall data distribution
    using the k-center greedy algorithm.

    Reference:
        Sener & Savarese, "Active Learning for Convolutional Neural Networks:
        A Core-Set Approach" (ICLR 2018)

    Args:
        batch_size: Number of samples to select
    """

    def __init__(self, batch_size: int = 100):
        self.batch_size = batch_size

    def query(
        self,
        model: torch.nn.Module,
        x_unlabeled: torch.Tensor,
        x_labeled: Optional[torch.Tensor] = None,
        features_unlabeled: Optional[torch.Tensor] = None,
        features_labeled: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Select core-set using greedy k-center.

        Args:
            model: Model for feature extraction
            x_unlabeled: Unlabeled data
            x_labeled: Labeled data (optional)
            features_unlabeled: Precomputed features for unlabeled data
            features_labeled: Precomputed features for labeled data

        Returns:
            Indices of selected samples
        """
        # Extract features
        if features_unlabeled is None:
            with torch.no_grad():
                model.eval()
                features_unlabeled = model(x_unlabeled)
                if features_unlabeled.dim() > 2:
                    features_unlabeled = features_unlabeled.flatten(1)

        if x_labeled is not None and features_labeled is None:
            with torch.no_grad():
                model.eval()
                features_labeled = model(x_labeled)
                if features_labeled.dim() > 2:
                    features_labeled = features_labeled.flatten(1)

        # Greedy k-center
        if features_labeled is not None:
            # Start with labeled data as centers
            centers = features_labeled
        else:
            # Start with empty set
            centers = torch.empty(
                0, features_unlabeled.size(1), device=features_unlabeled.device
            )

        selected = []
        for _ in range(min(self.batch_size, len(x_unlabeled))):
            if len(centers) == 0:
                # Select random first point
                idx = torch.randint(len(features_unlabeled), (1,)).item()
            else:
                # Compute distances to nearest center
                dists = torch.cdist(features_unlabeled, centers).min(dim=1).values
                # Select point farthest from centers
                idx = dists.argmax().item()

            selected.append(idx)

            # Add to centers
            centers = torch.cat([centers, features_unlabeled[idx : idx + 1]], dim=0)

        return torch.tensor(selected, device=x_unlabeled.device)

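# Illustrative usage sketch, not part of the library API. CoreSetSelection
# only needs features, so a toy feature extractor and a small already-labeled
# set are enough to show the greedy k-center selection.
def _demo_coreset_selection() -> torch.Tensor:
    """Hypothetical usage sketch for CoreSetSelection."""
    feature_extractor = torch.nn.Linear(16, 8)  # toy embedding model
    x_pool = torch.randn(200, 16)               # unlabeled pool
    x_labeled = torch.randn(20, 16)             # already-labeled samples
    strategy = CoreSetSelection(batch_size=10)
    # The labeled set seeds the centers; each selected point is the one
    # farthest from all current centers.
    return strategy.query(feature_extractor, x_pool, x_labeled=x_labeled)
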
class BadgeSampling:
    """
    BADGE (Batch Active learning by Diverse Gradient Embeddings).

    Combines gradient-based embeddings with k-MEANS++ for diverse batch
    selection.

    Reference:
        Ash et al., "Deep Batch Active Learning by Diverse, Uncertain
        Gradient Lower Bounds" (ICLR 2020)

    Args:
        batch_size: Number of samples to select
    """

    def __init__(self, batch_size: int = 100):
        self.batch_size = batch_size

    def _compute_gradient_embeddings(
        self,
        model: torch.nn.Module,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Compute gradient embeddings for samples."""
        model.eval()

        embeddings = []
        for i in range(len(x)):
            x_i = x[i : i + 1].requires_grad_(True)

            # Forward pass
            output = model(x_i)
            probs = F.softmax(output, dim=-1)

            # Compute gradients for each class
            grad_embeds = []
            for c in range(output.size(-1)):
                if x_i.grad is not None:
                    x_i.grad.zero_()

                # Backprop for class c
                probs[0, c].backward(retain_graph=True)

                # Get gradient
                if x_i.grad is not None:
                    grad = x_i.grad.flatten().detach()
                    grad_embeds.append(grad * probs[0, c].item())

            # Concatenate gradients
            embedding = (
                torch.cat(grad_embeds)
                if grad_embeds
                else torch.zeros(1, device=x.device)
            )
            embeddings.append(embedding)

        return torch.stack(embeddings)

    def query(
        self,
        model: torch.nn.Module,
        x_unlabeled: torch.Tensor,
    ) -> torch.Tensor:
        """
        Select batch using BADGE.

        Args:
            model: Model
            x_unlabeled: Unlabeled data

        Returns:
            Indices of selected samples
        """
        # Compute gradient embeddings
        embeddings = self._compute_gradient_embeddings(model, x_unlabeled)

        # k-MEANS++ initialization for diversity
        selected = []
        available = torch.arange(len(embeddings))

        # First point: random
        idx = torch.randint(len(embeddings), (1,)).item()
        selected.append(idx)

        # Remaining points: proportional to distance from nearest selected
        for _ in range(min(self.batch_size - 1, len(embeddings) - 1)):
            # Compute distances to nearest selected point
            dists = (
                torch.cdist(embeddings[available], embeddings[selected])
                .min(dim=1)
                .values
            )

            # Sample proportionally to squared distance
            probs = (dists**2) / (dists**2).sum()
            idx = available[torch.multinomial(probs, 1).item()].item()
            selected.append(idx)

            # Update available
            available = torch.tensor(
                [i for i in range(len(embeddings)) if i not in selected],
                device=x_unlabeled.device,
            )

        return torch.tensor(selected, device=x_unlabeled.device)


class QueryByCommittee:
    """
    Query by Committee (QBC).

    Uses disagreement among an ensemble of models to select informative
    samples.

    Reference:
        Seung et al., "Query by Committee" (COLT 1992)

    Args:
        models: List of committee members
        batch_size: Number of samples to select
        disagreement: Disagreement measure ('vote_entropy' or 'kl')
    """

    def __init__(
        self,
        models: List[torch.nn.Module],
        batch_size: int = 100,
        disagreement: str = "vote_entropy",
    ):
        self.models = models
        self.batch_size = batch_size
        self.disagreement = disagreement

    @torch.no_grad()
    def query(
        self,
        x_unlabeled: torch.Tensor,
    ) -> torch.Tensor:
        """
        Query using committee disagreement.

        Args:
            x_unlabeled: Unlabeled data

        Returns:
            Indices of selected samples
        """
        # Collect predictions from all committee members
        predictions = []
        for model in self.models:
            model.eval()
            logits = model(x_unlabeled)
            probs = F.softmax(logits, dim=-1)
            predictions.append(probs)

        predictions = torch.stack(predictions)  # (num_models, num_samples, num_classes)

        if self.disagreement == "vote_entropy":
            # Vote entropy: entropy of the vote distribution
            votes = predictions.argmax(dim=-1)  # (num_models, num_samples)

            # Count votes for each class
            num_classes = predictions.size(-1)
            vote_counts = torch.zeros(
                len(x_unlabeled), num_classes, device=x_unlabeled.device
            )
            for i in range(len(self.models)):
                for j in range(len(x_unlabeled)):
                    vote_counts[j, votes[i, j]] += 1

            # Compute entropy of the vote distribution
            vote_probs = vote_counts / len(self.models)
            scores = -(vote_probs * torch.log(vote_probs + 1e-10)).sum(dim=-1)

        elif self.disagreement == "kl":
            # Average KL divergence from the mean prediction
            mean_probs = predictions.mean(dim=0)

            scores = torch.zeros(len(x_unlabeled), device=x_unlabeled.device)
            for i in range(len(self.models)):
                kl = (
                    predictions[i]
                    * (
                        torch.log(predictions[i] + 1e-10)
                        - torch.log(mean_probs + 1e-10)
                    )
                ).sum(dim=-1)
                scores += kl
            scores /= len(self.models)

        else:
            raise ValueError(f"Unknown disagreement measure: {self.disagreement}")

        # Select top-k
        _, indices = torch.topk(scores, k=min(self.batch_size, len(scores)))

        return indices


__all__ = [
    "UncertaintySampling",
    "DiversitySampling",
    "CoreSetSelection",
    "BadgeSampling",
    "QueryByCommittee",
]
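

# Illustrative usage sketches, not part of the library API. BADGE only needs
# a differentiable classifier; QueryByCommittee needs a small ensemble. The
# models and shapes below are toy values chosen for the example.
def _demo_badge_and_qbc() -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical usage sketches for BadgeSampling and QueryByCommittee."""
    x_pool = torch.randn(50, 16)  # small pool (BADGE loops per sample)

    # BADGE: per-sample gradient embeddings followed by k-MEANS++ seeding.
    badge = BadgeSampling(batch_size=5)
    badge_indices = badge.query(torch.nn.Linear(16, 4), x_pool)

    # QBC: vote-entropy disagreement across a three-member committee.
    committee = [torch.nn.Linear(16, 4) for _ in range(3)]
    qbc = QueryByCommittee(committee, batch_size=5, disagreement="vote_entropy")
    qbc_indices = qbc.query(x_pool)

    return badge_indices, qbc_indices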