"""
Query strategies for active learning.
Strategies for selecting batches of samples to label, combining
acquisition functions with diversity and batch considerations.
"""
from __future__ import annotations
from typing import List, Optional, Callable
import torch
import torch.nn.functional as F
import numpy as np
from .acquisition import BaseAcquisition, EntropyAcquisition


class UncertaintySampling:
    """
    Uncertainty-based sampling strategy.

    Selects the samples with the highest acquisition scores.

    Args:
        acquisition_fn: Acquisition function to use
        batch_size: Number of samples to select per query

    Example:
        >>> acquisition = BALDAcquisition(num_samples=20)
        >>> strategy = UncertaintySampling(acquisition, batch_size=100)
        >>> indices = strategy.query(model, unlabeled_data)
    """

    def __init__(
        self,
        acquisition_fn: BaseAcquisition,
        batch_size: int = 100,
    ):
        self.acquisition_fn = acquisition_fn
        self.batch_size = batch_size

    def query(
        self,
        model: torch.nn.Module,
        x_unlabeled: torch.Tensor,
    ) -> torch.Tensor:
        """
        Query samples based on uncertainty.

        Args:
            model: Trained model
            x_unlabeled: Unlabeled data (N, *)

        Returns:
            Indices of selected samples (batch_size,)
        """
        # Compute per-sample acquisition scores
        scores = self.acquisition_fn.score(model, x_unlabeled)
        # Select the top-k highest-scoring samples
        _, indices = torch.topk(scores, k=min(self.batch_size, len(scores)))
        return indices


class DiversitySampling:
    """
    Diversity-based sampling with uncertainty.

    Balances uncertainty with diversity to avoid selecting redundant
    samples: each candidate ``i`` is scored as
    ``(1 - w) * uncertainty(i) + w * min_dist(i)``, where ``min_dist``
    is the distance to the nearest already-selected sample and ``w``
    is ``diversity_weight``.

    Reference:
        Brinker, "Incorporating Diversity in Active Learning with
        Support Vector Machines" (ICML 2003)

    Args:
        acquisition_fn: Acquisition function
        batch_size: Number of samples to select
        diversity_weight: Weight for diversity term (0-1)
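
    Example:
        A minimal sketch; assumes ``model`` is a trained classifier and
        ``unlabeled_data`` is a float tensor of inputs:

        >>> acquisition = EntropyAcquisition()
        >>> strategy = DiversitySampling(
        ...     acquisition, batch_size=50, diversity_weight=0.5
        ... )
        >>> indices = strategy.query(model, unlabeled_data)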
"""

    def __init__(
        self,
        acquisition_fn: BaseAcquisition,
        batch_size: int = 100,
        diversity_weight: float = 0.5,
    ):
        self.acquisition_fn = acquisition_fn
        self.batch_size = batch_size
        self.diversity_weight = diversity_weight

    def query(
        self,
        model: torch.nn.Module,
        x_unlabeled: torch.Tensor,
        features: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Query samples balancing uncertainty and diversity.

        Args:
            model: Trained model
            x_unlabeled: Unlabeled data
            features: Optional precomputed features for diversity computation

        Returns:
            Indices of selected samples
        """
        # Compute uncertainty scores
        uncertainty_scores = self.acquisition_fn.score(model, x_unlabeled)

        # Extract features if not provided
        if features is None:
            with torch.no_grad():
                model.eval()
                # Use the model outputs as features
                features = model(x_unlabeled)
                if features.dim() > 2:
                    features = features.flatten(1)

        # Normalize uncertainty scores to [0, 1]
        uncertainty_scores = (uncertainty_scores - uncertainty_scores.min()) / (
            uncertainty_scores.max() - uncertainty_scores.min() + 1e-10
        )

        # Greedy selection with diversity
        selected: List[int] = []
        available = torch.arange(len(x_unlabeled))
        for _ in range(min(self.batch_size, len(x_unlabeled))):
            if len(selected) == 0:
                # First sample: select the most uncertain
                selected.append(uncertainty_scores.argmax().item())
            else:
                # Diversity score: distance to the nearest selected sample
                selected_features = features[selected]
                diversity_scores = (
                    torch.cdist(features[available], selected_features)
                    .min(dim=1)
                    .values
                )
                # Normalize diversity scores to [0, 1]
                diversity_scores = (diversity_scores - diversity_scores.min()) / (
                    diversity_scores.max() - diversity_scores.min() + 1e-10
                )
                # Combined score over the remaining candidates
                combined = (1 - self.diversity_weight) * uncertainty_scores[
                    available
                ] + self.diversity_weight * diversity_scores
                # Select the best candidate
                best_idx = combined.argmax()
                selected.append(available[best_idx].item())
            # Remove selected samples from the candidate pool
            available = torch.tensor(
                [i for i in range(len(x_unlabeled)) if i not in selected]
            )
        return torch.tensor(selected, device=x_unlabeled.device)


class CoreSetSelection:
    """
    Core-Set selection for active learning.

    Selects samples that best cover the data distribution using the
    greedy k-center algorithm: each step picks the point farthest from
    the current set of centers.

    Reference:
        Sener & Savarese, "Active Learning for Convolutional Neural
        Networks: A Core-Set Approach" (ICLR 2018)

    Args:
        batch_size: Number of samples to select
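
    Example:
        A minimal sketch; assumes ``model`` maps inputs to feature vectors
        (or logits) and that the data tensors live on its device:

        >>> strategy = CoreSetSelection(batch_size=100)
        >>> indices = strategy.query(model, x_unlabeled, x_labeled=x_labeled)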
"""
def __init__(self, batch_size: int = 100):
self.batch_size = batch_size

    def query(
        self,
        model: torch.nn.Module,
        x_unlabeled: torch.Tensor,
        x_labeled: Optional[torch.Tensor] = None,
        features_unlabeled: Optional[torch.Tensor] = None,
        features_labeled: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Select a core-set using greedy k-center.

        Args:
            model: Model for feature extraction
            x_unlabeled: Unlabeled data
            x_labeled: Labeled data (optional)
            features_unlabeled: Precomputed features for unlabeled data
            features_labeled: Precomputed features for labeled data

        Returns:
            Indices of selected samples
        """
        # Extract features if not provided
        if features_unlabeled is None:
            with torch.no_grad():
                model.eval()
                features_unlabeled = model(x_unlabeled)
                if features_unlabeled.dim() > 2:
                    features_unlabeled = features_unlabeled.flatten(1)
        if x_labeled is not None and features_labeled is None:
            with torch.no_grad():
                model.eval()
                features_labeled = model(x_labeled)
                if features_labeled.dim() > 2:
                    features_labeled = features_labeled.flatten(1)

        # Greedy k-center
        if features_labeled is not None:
            # Start with the labeled data as centers
            centers = features_labeled
        else:
            # Start with an empty center set
            centers = torch.empty(
                0,
                features_unlabeled.size(1),
                dtype=features_unlabeled.dtype,
                device=features_unlabeled.device,
            )

        selected: List[int] = []
        for _ in range(min(self.batch_size, len(x_unlabeled))):
            if len(centers) == 0:
                # No centers yet: pick a random first point
                idx = torch.randint(len(features_unlabeled), (1,)).item()
            else:
                # Distance from each candidate to its nearest center
                dists = torch.cdist(features_unlabeled, centers).min(dim=1).values
                # Pick the point farthest from all centers
                idx = dists.argmax().item()
            selected.append(idx)
            # Add the new point to the centers
            centers = torch.cat([centers, features_unlabeled[idx : idx + 1]], dim=0)
        return torch.tensor(selected, device=x_unlabeled.device)


class BadgeSampling:
    """
    BADGE (Batch Active learning by Diverse Gradient Embeddings).

    Combines gradient-based embeddings with k-means++ seeding for
    diverse batch selection.

    Reference:
        Ash et al., "Deep Batch Active Learning by Diverse, Uncertain
        Gradient Lower Bounds" (ICLR 2020)

    Args:
        batch_size: Number of samples to select
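
    Example:
        A minimal sketch; assumes ``model`` outputs classification logits:

        >>> strategy = BadgeSampling(batch_size=100)
        >>> indices = strategy.query(model, x_unlabeled)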
"""
def __init__(self, batch_size: int = 100):
self.batch_size = batch_size

    def _compute_gradient_embeddings(
        self,
        model: torch.nn.Module,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Compute per-sample gradient embeddings.

        Note: this implementation uses gradients of the class
        probabilities with respect to the input, weighted by those
        probabilities, rather than the last-layer parameter gradients
        of the original paper.
        """
        model.eval()
        embeddings = []
        for i in range(len(x)):
            x_i = x[i : i + 1].detach().clone().requires_grad_(True)
            # Forward pass
            output = model(x_i)
            probs = F.softmax(output, dim=-1)
            # Accumulate the input gradient of each class probability
            grad_embeds = []
            for c in range(output.size(-1)):
                if x_i.grad is not None:
                    x_i.grad.zero_()
                # Backprop the probability of class c
                probs[0, c].backward(retain_graph=True)
                if x_i.grad is not None:
                    grad = x_i.grad.flatten().detach()
                    grad_embeds.append(grad * probs[0, c].item())
            # Concatenate the per-class gradients into one embedding
            embedding = (
                torch.cat(grad_embeds)
                if grad_embeds
                else torch.zeros(1, device=x.device)
            )
            embeddings.append(embedding)
        return torch.stack(embeddings)

    def query(
        self,
        model: torch.nn.Module,
        x_unlabeled: torch.Tensor,
    ) -> torch.Tensor:
        """
        Select a batch using BADGE.

        Args:
            model: Model
            x_unlabeled: Unlabeled data

        Returns:
            Indices of selected samples
        """
        # Compute gradient embeddings
        embeddings = self._compute_gradient_embeddings(model, x_unlabeled)

        # k-means++ seeding for diversity
        selected: List[int] = []

        # First point: uniform random
        idx = torch.randint(len(embeddings), (1,)).item()
        selected.append(idx)

        # Remaining points: sampled proportionally to the squared
        # distance from the nearest already-selected point
        for _ in range(min(self.batch_size - 1, len(embeddings) - 1)):
            # Remove selected points from the candidate pool
            available = torch.tensor(
                [i for i in range(len(embeddings)) if i not in selected]
            )
            # Distance from each candidate to its nearest selected point
            dists = (
                torch.cdist(embeddings[available], embeddings[selected])
                .min(dim=1)
                .values
            )
            # Sample proportionally to the squared distance, falling back
            # to uniform sampling if all candidates coincide with a
            # selected point
            sq_dists = dists**2
            if sq_dists.sum() <= 0:
                probs = torch.ones_like(sq_dists) / len(sq_dists)
            else:
                probs = sq_dists / sq_dists.sum()
            idx = available[torch.multinomial(probs, 1).item()].item()
            selected.append(idx)
        return torch.tensor(selected, device=x_unlabeled.device)


class QueryByCommittee:
    """
    Query by Committee (QBC).

    Uses disagreement among an ensemble of models to select informative
    samples.

    Reference:
        Seung et al., "Query by Committee" (COLT 1992)

    Args:
        models: List of committee members
        batch_size: Number of samples to select
        disagreement: Disagreement measure ('vote_entropy' or 'kl')
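
    Example:
        A minimal sketch; assumes ``model_a`` and ``model_b`` are trained
        classifiers with the same label space:

        >>> committee = QueryByCommittee([model_a, model_b], batch_size=100)
        >>> indices = committee.query(x_unlabeled)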
"""
def __init__(
self,
models: List[torch.nn.Module],
batch_size: int = 100,
disagreement: str = "vote_entropy",
):
self.models = models
self.batch_size = batch_size
self.disagreement = disagreement

    @torch.no_grad()
    def query(
        self,
        x_unlabeled: torch.Tensor,
    ) -> torch.Tensor:
        """
        Query using committee disagreement.

        Args:
            x_unlabeled: Unlabeled data

        Returns:
            Indices of selected samples
        """
        # Collect predictions from all committee members
        predictions = []
        for model in self.models:
            model.eval()
            logits = model(x_unlabeled)
            probs = F.softmax(logits, dim=-1)
            predictions.append(probs)
        predictions = torch.stack(predictions)  # (num_models, num_samples, num_classes)

        if self.disagreement == "vote_entropy":
            # Vote entropy: entropy of the committee's vote distribution
            votes = predictions.argmax(dim=-1)  # (num_models, num_samples)
            num_classes = predictions.size(-1)
            # Count votes per class for each sample
            vote_counts = F.one_hot(votes, num_classes).sum(dim=0).float()
            # Entropy of the normalized vote distribution
            vote_probs = vote_counts / len(self.models)
            scores = -(vote_probs * torch.log(vote_probs + 1e-10)).sum(dim=-1)
        elif self.disagreement == "kl":
            # Average KL divergence of each member from the consensus
            mean_probs = predictions.mean(dim=0)
            kl = (
                predictions
                * (torch.log(predictions + 1e-10) - torch.log(mean_probs + 1e-10))
            ).sum(dim=-1)  # (num_models, num_samples)
            scores = kl.mean(dim=0)
        else:
            raise ValueError(f"Unknown disagreement measure: {self.disagreement}")

        # Select the top-k most-disagreed samples
        _, indices = torch.topk(scores, k=min(self.batch_size, len(scores)))
        return indices


__all__ = [
"UncertaintySampling",
"DiversitySampling",
"CoreSetSelection",
"BadgeSampling",
"QueryByCommittee",
]