Source code for incerto.active.utils

"""
Utility functions for active learning.
"""

from __future__ import annotations
import logging
import torch
from typing import Callable, Optional, Tuple

logger = logging.getLogger(__name__)


def split_labeled_unlabeled(
    data: torch.Tensor,
    labels: Optional[torch.Tensor] = None,
    labeled_indices: Optional[torch.Tensor] = None,
    unlabeled_indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Split data into labeled and unlabeled sets.

    Args:
        data: Full dataset ``(N, ...)``
        labels: Labels for all data ``(N,)``, can have ``-1`` for unlabeled
        labeled_indices: Indices of labeled samples
        unlabeled_indices: Indices of unlabeled samples

    Returns:
        Tuple of ``(x_labeled, x_unlabeled, y_labeled, labeled_indices)``
    """
    if labeled_indices is None and unlabeled_indices is None:
        if labels is None:
            raise ValueError("Must provide either labels or indices")
        # Use the -1 sentinel in labels to determine the split
        labeled_mask = labels != -1
        labeled_indices = torch.where(labeled_mask)[0]
        unlabeled_indices = torch.where(~labeled_mask)[0]
    elif labeled_indices is None:
        # Compute labeled indices as the complement of the unlabeled set
        all_indices = torch.arange(len(data))
        labeled_mask = ~torch.isin(all_indices, unlabeled_indices)
        labeled_indices = all_indices[labeled_mask]
    elif unlabeled_indices is None:
        # Compute unlabeled indices as the complement of the labeled set
        all_indices = torch.arange(len(data))
        unlabeled_mask = ~torch.isin(all_indices, labeled_indices)
        unlabeled_indices = all_indices[unlabeled_mask]

    x_labeled = data[labeled_indices]
    x_unlabeled = data[unlabeled_indices]
    y_labeled = labels[labeled_indices] if labels is not None else None

    return x_labeled, x_unlabeled, y_labeled, labeled_indices
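
# Usage sketch (illustrative only): splitting a toy dataset where the
# sentinel label -1 marks unlabeled samples.
#
#     data = torch.randn(6, 4)
#     labels = torch.tensor([0, -1, 1, -1, -1, 0])
#     x_lab, x_unlab, y_lab, lab_idx = split_labeled_unlabeled(data, labels)
#     # x_lab and x_unlab each have 3 rows; y_lab == tensor([0, 1, 0])
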
def compute_diversity_penalty(
    selected: torch.Tensor,
    features: torch.Tensor,
    method: str = "min_distance",
) -> torch.Tensor:
    """
    Compute diversity penalty for selected samples.

    Args:
        selected: Indices of selected samples
        features: Feature representations of all samples
        method: Diversity measure (``'min_distance'``, ``'mean_distance'``,
            ``'determinant'``)

    Returns:
        Diversity score. For ``'min_distance'`` and ``'mean_distance'``,
        higher values indicate more diversity. For ``'determinant'``, lower
        (more negative) values indicate more diversity.
    """
    if len(selected) == 0:
        return torch.tensor(0.0, device=features.device)

    selected_features = features[selected]

    if method == "min_distance":
        # Minimum pairwise distance (lower = less diverse)
        if len(selected) < 2:
            return torch.tensor(float("inf"), device=features.device)
        dists = torch.cdist(selected_features, selected_features)
        # Mask the diagonal so self-distances are ignored
        mask = torch.eye(len(selected), device=dists.device).bool()
        dists[mask] = float("inf")
        return dists.min()

    elif method == "mean_distance":
        # Mean pairwise distance over all off-diagonal pairs
        if len(selected) < 2:
            return torch.tensor(float("inf"), device=features.device)
        dists = torch.cdist(selected_features, selected_features)
        mask = torch.eye(len(selected), device=dists.device).bool()
        dists[mask] = 0
        return dists.sum() / (len(selected) * (len(selected) - 1))

    elif method == "determinant":
        # Negative log-determinant of the feature covariance: a larger
        # determinant means more spread, so the returned value is lower
        # (more negative) for more diverse selections
        if len(selected) < features.size(1):
            # Need at least d samples for a full-rank d-dimensional covariance
            return torch.tensor(0.0, device=features.device)
        cov = torch.cov(selected_features.T)
        try:
            det = torch.linalg.det(cov)
            return -torch.log(det.clamp(min=1e-10))  # Negative log for penalty
        except Exception:
            return torch.tensor(0.0, device=features.device)

    else:
        raise ValueError(f"Unknown method: {method}")
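
# Usage sketch (illustrative only): a tight pair scores lower than a
# spread-out pair under the default 'min_distance' measure.
#
#     features = torch.tensor([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0]])
#     tight = compute_diversity_penalty(torch.tensor([0, 1]), features)
#     spread = compute_diversity_penalty(torch.tensor([0, 2]), features)
#     # tight ~= 0.1, spread ~= 7.07: the spread selection is more diverse
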
def greedy_k_center(
    features: torch.Tensor,
    k: int,
    initial_centers: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Greedy k-center algorithm for diverse sample selection.

    Iteratively selects points that are farthest from all previously
    selected points.

    Args:
        features: Feature representations ``(N, D)``
        k: Number of centers to select (clamped to ``N`` if larger)
        initial_centers: Optional initial centers

    Returns:
        Indices of selected centers
    """
    n = len(features)
    k = min(k, n)  # Can't select more than available
    selected = []

    # Initialize with existing centers if provided
    if initial_centers is not None:
        centers = initial_centers
    else:
        centers = torch.empty(0, features.size(1), device=features.device)

    # Distance from each point to its nearest current center
    if len(centers) > 0:
        min_dists = torch.cdist(features, centers).min(dim=1).values
    else:
        min_dists = torch.full((n,), float("inf"), device=features.device)

    for _ in range(k):
        # Select the point with maximum distance to its nearest center
        idx = min_dists.argmax().item()
        selected.append(idx)

        # Add the new center
        centers = torch.cat([centers, features[idx : idx + 1]], dim=0)

        # Update minimum distances against the new center
        new_dists = torch.norm(features - features[idx : idx + 1], dim=1)
        min_dists = torch.minimum(min_dists, new_dists)

    return torch.tensor(selected, device=features.device)
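
# Usage sketch (illustrative only). With no initial centers every initial
# distance is infinite, so argmax resolves the tie to index 0 and the first
# pick is deterministic; each later pick maximizes the distance to the
# nearest already-chosen center.
#
#     feats = torch.tensor([[0.0], [0.1], [10.0], [10.1], [5.0]])
#     idx = greedy_k_center(feats, k=3)
#     # idx == tensor([0, 3, 4]): one point from each well-separated region
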
def subsample_for_efficiency(
    data: torch.Tensor,
    max_samples: int = 10000,
    random_seed: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Subsample data for computational efficiency.

    Args:
        data: Full dataset
        max_samples: Maximum number of samples to keep
        random_seed: Random seed for reproducibility

    Returns:
        Tuple of ``(subsampled_data, indices)``
    """
    if len(data) <= max_samples:
        return data, torch.arange(len(data), device=data.device)

    if random_seed is not None:
        torch.manual_seed(random_seed)

    indices = torch.randperm(len(data), device=data.device)[:max_samples]
    indices = indices.sort().values  # Sort for consistent ordering

    return data[indices], indices
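
# Usage sketch (illustrative only): capping a pool at 1000 points with a
# fixed seed so repeated runs draw the same subsample.
#
#     pool = torch.randn(5000, 8)
#     sub, idx = subsample_for_efficiency(pool, max_samples=1000, random_seed=0)
#     # sub.shape == (1000, 8); idx is sorted and pool[idx] equals sub
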
def active_learning_loop(
    model: torch.nn.Module,
    x_pool: torch.Tensor,
    y_pool: torch.Tensor,
    strategy,
    num_rounds: int = 10,
    initial_labeled: int = 100,
    train_fn: Optional[Callable] = None,
    eval_fn: Optional[Callable] = None,
    random_seed: Optional[int] = None,
) -> dict:
    """
    Run a full active learning loop.

    Args:
        model: Model to train
        x_pool: Full data pool
        y_pool: Full labels
        strategy: Query strategy (batch size is controlled by the strategy)
        num_rounds: Number of active learning rounds
        initial_labeled: Number of initial labeled samples
        train_fn: Function to train the model, called as
            ``train_fn(model, x_train, y_train)``
        eval_fn: Function to evaluate the model, called as ``eval_fn(model)``;
            any evaluation data must be closed over by the caller
        random_seed: Random seed

    Returns:
        Dictionary with results
    """
    if random_seed is not None:
        torch.manual_seed(random_seed)

    # Initialize with random labeled samples
    n_total = len(x_pool)
    all_indices = torch.arange(n_total)
    labeled_indices = all_indices[torch.randperm(n_total)[:initial_labeled]]
    unlabeled_mask = ~torch.isin(all_indices, labeled_indices)
    unlabeled_indices = all_indices[unlabeled_mask]

    results = {
        "labeled_sizes": [],
        "accuracies": [],
        "selected_indices": [],
    }

    for round_idx in range(num_rounds):
        logger.info("Round %d/%d", round_idx + 1, num_rounds)

        # Get current split
        x_train = x_pool[labeled_indices]
        y_train = y_pool[labeled_indices]
        x_unlabeled = x_pool[unlabeled_indices]

        # Train model
        if train_fn is not None:
            train_fn(model, x_train, y_train)

        # Evaluate
        if eval_fn is not None:
            accuracy = eval_fn(model)
            results["accuracies"].append(accuracy)
            logger.info("  Accuracy: %.4f", accuracy)

        results["labeled_sizes"].append(len(labeled_indices))

        # Query next batch
        if len(unlabeled_indices) == 0:
            logger.info("No more unlabeled samples")
            break

        query_indices = strategy.query(model, x_unlabeled)

        # Map query indices back to pool indices (ensure CPU for indexing)
        query_indices_cpu = query_indices.cpu()
        selected = unlabeled_indices[query_indices_cpu]
        results["selected_indices"].append(selected)

        # Update labeled/unlabeled sets
        labeled_indices = torch.cat([labeled_indices, selected])
        unlabeled_mask = torch.ones(n_total, dtype=torch.bool)
        unlabeled_mask[labeled_indices] = False
        unlabeled_indices = torch.where(unlabeled_mask)[0]

    return results
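
# Usage sketch (illustrative only). `RandomStrategy`, `train`, and `evaluate`
# are hypothetical stand-ins: any strategy exposing `.query(model, x)` and any
# train/eval closures with the documented call signatures will do. Note that
# `eval_fn` receives only the model, so test data must be captured by the
# closure.
#
#     strategy = RandomStrategy(batch_size=50)  # hypothetical strategy class
#     results = active_learning_loop(
#         model, x_pool, y_pool, strategy,
#         num_rounds=5, initial_labeled=100,
#         train_fn=train, eval_fn=evaluate, random_seed=0,
#     )
#     # results["labeled_sizes"] == [100, 150, 200, 250, 300]
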
__all__ = [
    "split_labeled_unlabeled",
    "compute_diversity_penalty",
    "greedy_k_center",
    "subsample_for_efficiency",
    "active_learning_loop",
]