# Source code for incerto.shift.methods

"""
incerto.shift.methods
=====================

Fast, sklearn-style wrappers around common shift-detection techniques.
Each detector exposes two methods:

    .fit(reference_loader)           # builds any reference statistics
    .score(test_loader) -> float     # returns a scalar shift score
"""

from __future__ import annotations
from typing import Optional

import torch
from torch.utils.data import DataLoader
from scipy import stats
from incerto.core.utils import pairwise_squared_euclidean

from .base import BaseShiftDetector
from . import metrics


# ------------------------------------------------------------------------- #
#   Non-parametric two-sample tests
# ------------------------------------------------------------------------- #
class MMDShiftDetector(BaseShiftDetector):
    r"""Kernel Maximum Mean Discrepancy with Gaussian (RBF) kernel.

    Computes the biased MMD estimator which includes diagonal terms.
    For large sample sizes, this converges to the true MMD.

    Reference: Gretton et al., "A Kernel Two-Sample Test" (JMLR 2012)

    Args:
        sigma: RBF kernel bandwidth parameter (default: 1.0)
    """

    def __init__(self, sigma: float = 1.0) -> None:
        self.sigma = sigma

    def _rbf(self, x, y):
        # Gaussian kernel matrix: exp(-||x - y||^2 / (2 * sigma^2)).
        bandwidth = 2 * self.sigma**2
        return torch.exp(-pairwise_squared_euclidean(x, y) / bandwidth)

    def _compute(self, test: torch.Tensor) -> float:
        # Biased estimator: mean over all kernel entries, diagonal included.
        ref = self._reference
        within_ref = self._rbf(ref, ref).mean()
        within_test = self._rbf(test, test).mean()
        cross = self._rbf(ref, test).mean()
        return (within_ref + within_test - 2 * cross).item()

    def state_dict(self) -> dict:
        """Save MMD detector state."""
        state = super().state_dict()
        state["sigma"] = self.sigma
        return state

    def load_state_dict(self, state: dict) -> None:
        """Load MMD detector state."""
        super().load_state_dict(state)
        self.sigma = state["sigma"]

    def __repr__(self) -> str:
        if hasattr(self, "_reference"):
            return f"MMDShiftDetector(sigma={self.sigma}, n_ref_samples={len(self._reference)})"
        return f"MMDShiftDetector(sigma={self.sigma}, n_ref_samples=not fitted)"
class EnergyShiftDetector(BaseShiftDetector):
    """Energy distance – Szekely & Rizzo, 2013."""

    def _compute(self, test: torch.Tensor) -> float:
        # Delegate to the shared metric implementation.
        return metrics.energy_distance(self._reference, test)

    def __repr__(self) -> str:
        if hasattr(self, "_reference"):
            return f"EnergyShiftDetector(n_ref_samples={len(self._reference)})"
        return "EnergyShiftDetector(n_ref_samples=not fitted)"
class KSShiftDetector(BaseShiftDetector):
    """One-dimensional Kolmogorov–Smirnov test (per feature, max statistic)."""

    def _compute(self, test: torch.Tensor) -> float:
        # Run a two-sample KS test per feature column and report the
        # worst-case (maximum) statistic across features.
        statistics = []
        for col, ref_col in enumerate(self._reference.T):
            result = stats.ks_2samp(
                ref_col.cpu().numpy(), test[:, col].cpu().numpy()
            )
            statistics.append(result.statistic)
        return max(statistics)

    def __repr__(self) -> str:
        if not hasattr(self, "_reference"):
            return "KSShiftDetector(not fitted)"
        n_samples = len(self._reference)
        n_features = self._reference.shape[1] if self._reference.dim() > 1 else 1
        return f"KSShiftDetector(n_ref_samples={n_samples}, n_features={n_features})"
# ------------------------------------------------------------------------- #
#   Black-box shift detectors (BBSD, classifier-based)
# ------------------------------------------------------------------------- #
class ClassifierShiftDetector(BaseShiftDetector):
    r"""Train a logistic regression to separate reference and test.

    * Lipton et al., 2018 (Black Box Shift Detection)
    """

    def __init__(self, clf_factory=None, device: Optional[str] = None) -> None:
        from sklearn.linear_model import LogisticRegression

        if clf_factory:
            self.clf = clf_factory()
        else:
            self.clf = LogisticRegression(max_iter=1000)
        self.device = device

    def _compute(self, test: torch.Tensor) -> float:
        import numpy as np
        from sklearn.model_selection import cross_val_predict, StratifiedKFold

        ref_np = self._reference.cpu().numpy()
        test_np = test.cpu().numpy()
        features = np.concatenate([ref_np, test_np], axis=0)
        domain = np.concatenate([np.zeros(len(ref_np)), np.ones(len(test_np))])

        # Cross-validated predictions avoid train/test leakage.
        n_splits = max(2, min(5, len(ref_np), len(test_np)))
        folds = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=0)
        proba = cross_val_predict(
            self.clf, features, domain, cv=folds, method="predict_proba"
        )[:, 1]

        # Under no shift the classifier cannot separate the two samples,
        # so the mean predicted probability on the test half should be ~0.5.
        mean_test_proba = float(proba[len(ref_np):].mean())
        return abs(mean_test_proba - 0.5) * 2

    def state_dict(self) -> dict:
        """Save classifier shift detector state.

        Note: The classifier itself is not serialized because it is
        re-fitted on every call to score(). A fresh LogisticRegression
        will be used after loading. Pass clf_factory again if you need
        a different classifier.
        """
        state = super().state_dict()
        state["device"] = self.device
        return state

    def load_state_dict(self, state: dict) -> None:
        """Load classifier shift detector state."""
        from sklearn.linear_model import LogisticRegression

        super().load_state_dict(state)
        self.device = state.get("device")
        if not hasattr(self, "clf"):
            self.clf = LogisticRegression(max_iter=1000)

    def __repr__(self) -> str:
        if hasattr(self, "_reference"):
            return f"ClassifierShiftDetector(n_ref_samples={len(self._reference)})"
        return "ClassifierShiftDetector(n_ref_samples=not fitted)"


# Alias for BBSD
BBSDDetector = ClassifierShiftDetector
class LabelShiftDetector:
    """
    Black-Box Shift Detection for label shift.

    Detects and quantifies label shift (prior probability shift)
    using confusion matrix estimation.

    Reference: Lipton et al., "Detecting and Correcting for Label Shift
    with Black Box Predictors" (ICML 2018)

    Args:
        num_classes: Number of classes
        calibrated: Whether predictions are calibrated
    """

    def __init__(self, num_classes: int, calibrated: bool = False):
        self.num_classes = num_classes
        self.calibrated = calibrated
        # Populated by fit(): P(y) on the source domain, and the
        # row-normalized confusion matrix C[i, j] = P(pred=j | true=i).
        self.source_label_dist = None
        self.confusion_matrix = None

    def fit(
        self,
        model: torch.nn.Module,
        source_loader: DataLoader,
        validation_loader: DataLoader,
    ):
        """
        Fit label shift detector.

        Args:
            model: Trained classifier
            source_loader: Source domain data with labels
            validation_loader: Validation set from source domain
        """
        model.eval()
        # Detect model device so inputs can be moved to it.
        device = next(model.parameters()).device

        # Source label prior from the labelled source loader.
        labels = torch.cat([y for _, y in source_loader])
        label_counts = torch.bincount(labels, minlength=self.num_classes)
        self.source_label_dist = label_counts.float() / label_counts.sum()

        # Confusion matrix from held-out validation predictions.
        confusion = torch.zeros(self.num_classes, self.num_classes)
        with torch.no_grad():
            for x, y in validation_loader:
                preds = model(x.to(device)).argmax(dim=-1).cpu()
                for true_label in range(self.num_classes):
                    mask = y == true_label
                    if mask.sum() > 0:
                        confusion[true_label] += torch.bincount(
                            preds[mask], minlength=self.num_classes
                        ).float()

        # Normalize rows (C[i,j] = P(pred=j | true=i)).
        row_sums = confusion.sum(dim=1, keepdim=True)
        row_sums[row_sums == 0] = 1  # Avoid division by zero
        self.confusion_matrix = confusion / row_sums

    def estimate_target_distribution(
        self,
        model: torch.nn.Module,
        target_loader: DataLoader,
    ) -> torch.Tensor:
        """
        Estimate target label distribution.

        Args:
            model: Trained classifier
            target_loader: Target domain data (no labels needed)

        Returns:
            Estimated target label distribution
        """
        if self.confusion_matrix is None:
            raise RuntimeError("Must call fit() first")

        model.eval()
        device = next(model.parameters()).device

        # Hard predictions on the (unlabelled) target data.
        pred_batches = []
        with torch.no_grad():
            for x, _ in target_loader:
                pred_batches.append(model(x.to(device)).argmax(dim=-1).cpu())
        predictions = torch.cat(pred_batches)

        # Empirical prediction distribution.
        pred_counts = torch.bincount(predictions, minlength=self.num_classes)
        pred_dist = pred_counts.float() / pred_counts.sum()

        # Solve pred_dist = C^T @ target_dist for target_dist; fall back
        # to least squares when the confusion matrix is singular.
        try:
            target_dist = torch.linalg.solve(self.confusion_matrix.T, pred_dist)
        except Exception:
            target_dist = torch.linalg.lstsq(
                self.confusion_matrix.T, pred_dist
            ).solution
        # Project onto the simplex (non-negative, sums to 1).
        target_dist = torch.clamp(target_dist, min=0)
        return target_dist / target_dist.sum()

    def compute_shift_magnitude(
        self,
        model: torch.nn.Module,
        target_loader: DataLoader,
        metric: str = "tvd",
    ) -> float:
        """
        Compute magnitude of label shift.

        Args:
            model: Trained classifier
            target_loader: Target domain data
            metric: Metric to use ('tvd', 'kl', 'l2')

        Returns:
            Shift magnitude
        """
        target_dist = self.estimate_target_distribution(model, target_loader)
        source_dist = self.source_label_dist

        if metric == "tvd":
            # Total variation distance
            return 0.5 * torch.abs(target_dist - source_dist).sum().item()
        if metric == "kl":
            # KL divergence (epsilon-smoothed to avoid log(0))
            ratio = (target_dist + 1e-10) / (source_dist + 1e-10)
            return (target_dist * torch.log(ratio)).sum().item()
        if metric == "l2":
            # L2 distance
            return torch.norm(target_dist - source_dist, p=2).item()
        raise ValueError(f"Unknown metric: {metric}")

    def state_dict(self) -> dict:
        """Save label shift detector state."""
        return {
            "num_classes": self.num_classes,
            "calibrated": self.calibrated,
            "source_label_dist": self.source_label_dist,
            "confusion_matrix": self.confusion_matrix,
        }

    def load_state_dict(self, state: dict) -> None:
        """Load label shift detector state."""
        from ..exceptions import SerializationError

        try:
            self.num_classes = state["num_classes"]
            self.calibrated = state["calibrated"]
            self.source_label_dist = state["source_label_dist"]
            self.confusion_matrix = state["confusion_matrix"]
        except Exception as e:
            raise SerializationError(f"Failed to load state: {e}") from e

    def save(self, path: str) -> None:
        """Save label shift detector state."""
        from ..exceptions import SerializationError

        try:
            torch.save(self.state_dict(), path)
        except Exception as e:
            raise SerializationError(f"Failed to save to {path}: {e}") from e

    @classmethod
    def load(
        cls, path: str, num_classes: int = 0, calibrated: bool = False
    ) -> "LabelShiftDetector":
        """Load label shift detector from file.

        Args:
            path: File path to load the state from.
            num_classes: Ignored; restored from saved state. Kept for
                backward compatibility.
            calibrated: Ignored; restored from saved state. Kept for
                backward compatibility.
        """
        from ..exceptions import SerializationError

        try:
            state = torch.load(path, weights_only=True)
            instance = cls(
                state.get("num_classes", num_classes),
                state.get("calibrated", calibrated),
            )
            instance.load_state_dict(state)
            return instance
        except Exception as e:
            raise SerializationError(f"Failed to load from {path}: {e}") from e

    def __repr__(self) -> str:
        fitted = self.source_label_dist is not None
        return (
            f"LabelShiftDetector(num_classes={self.num_classes}, "
            f"calibrated={self.calibrated}, fitted={fitted})"
        )
class ImportanceWeightingShift:
    """
    Importance weighting for covariate shift adaptation.

    Estimates density ratio w(x) = p_target(x) / p_source(x) and uses it
    to re-weight training samples.

    Reference: Sugiyama et al., "Direct Importance Estimation with Model
    Selection" (NIPS 2007)

    Args:
        method: Estimation method ('kernel', 'logistic', 'kliep')
        alpha: Regularization parameter
    """

    def __init__(self, method: str = "logistic", alpha: float = 0.01):
        self.method = method
        self.alpha = alpha
        # Either a fitted sklearn classifier ('logistic') or a tensor of
        # per-sample weights ('kernel'); None until fit() is called.
        self.weights_model = None

    def fit(
        self,
        source_features: torch.Tensor,
        target_features: torch.Tensor,
    ):
        """
        Estimate importance weights.

        Args:
            source_features: Features from source domain (N_s, D)
            target_features: Features from target domain (N_t, D)
        """
        import numpy as np

        if self.method == "logistic":
            # Train a logistic regression to discriminate source vs target.
            from sklearn.linear_model import LogisticRegression

            src = source_features.cpu().numpy()
            tgt = target_features.cpu().numpy()
            features = np.concatenate([src, tgt], axis=0)
            domain = np.concatenate([np.zeros(len(src)), np.ones(len(tgt))])

            clf = LogisticRegression(C=1.0 / self.alpha, max_iter=1000)
            clf.fit(features, domain)
            self.weights_model = clf
        elif self.method == "kernel":
            # Kernel mean matching (KMM).
            self._fit_kernel_weights(source_features, target_features)
        else:
            raise ValueError(f"Unknown method: {self.method}")

    def _fit_kernel_weights(
        self,
        source_features: torch.Tensor,
        target_features: torch.Tensor,
    ):
        """Fit kernel-based importance weights (simplified KMM)."""
        import numpy as np
        from scipy.optimize import minimize
        from sklearn.metrics.pairwise import rbf_kernel

        src = source_features.cpu().numpy()
        tgt = target_features.cpu().numpy()
        n_src = len(src)
        n_tgt = len(tgt)

        # Gram matrix with a small ridge term for numerical stability.
        gram = rbf_kernel(src, src) + 1e-6 * np.eye(n_src)
        kappa = rbf_kernel(src, tgt).mean(axis=1)

        # Simplified QP:
        #   min 0.5 * w^T K w - kappa^T w
        #   s.t. w >= 0, mean(w) = 1
        def objective(w):
            w = np.clip(w, 0, None)
            with np.errstate(invalid="ignore", over="ignore", divide="ignore"):
                val = 0.5 * w @ gram @ w - kappa @ w + self.alpha * (w**2).sum()
            return float(np.nan_to_num(val, nan=1e10, posinf=1e10))

        cap = 10.0 * n_src / n_tgt  # Upper bound to prevent explosion
        result = minimize(
            objective,
            np.ones(n_src),
            bounds=[(0, cap) for _ in range(n_src)],
            constraints={"type": "eq", "fun": lambda w: w.mean() - 1},
            method="SLSQP",
        )
        self.weights_model = torch.tensor(result.x, dtype=torch.float32)

    def compute_weights(
        self,
        source_features: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute importance weights for source samples.

        Args:
            source_features: Source domain features

        Returns:
            Importance weights
        """
        if self.weights_model is None:
            raise RuntimeError("Must call fit() first")

        if self.method == "logistic":
            # w(x) = P(target|x) / P(source|x) ~ p / (1 - p)
            probs = self.weights_model.predict_proba(
                source_features.cpu().numpy()
            )[:, 1]
            weights = torch.tensor(probs / (1 - probs + 1e-10), dtype=torch.float32)
            # Clamp to prevent explosion when probs ≈ 1, then normalize
            # so the weights average to 1 (sum to n).
            weights = weights.clamp(max=100.0)
            return weights / weights.mean()
        if self.method == "kernel":
            # Kernel weights only apply to the exact samples seen in fit().
            if len(source_features) != len(self.weights_model):
                raise ValueError(
                    f"Kernel weights are specific to the {len(self.weights_model)} "
                    f"source samples used in fit(). Got {len(source_features)} samples. "
                    f"Kernel method does not generalize to new samples."
                )
            return self.weights_model
        raise ValueError(f"Unknown method: {self.method}")

    def weighted_loss(
        self,
        loss: torch.Tensor,
        weights: torch.Tensor,
    ) -> torch.Tensor:
        """
        Apply importance weights to loss.

        Args:
            loss: Per-sample losses
            weights: Importance weights

        Returns:
            Weighted average loss
        """
        return (loss * weights).mean()

    def state_dict(self) -> dict:
        """Save importance weighting state."""
        from .._sklearn_io import serialize_logistic

        if self.weights_model is None:
            model_data = None
        elif isinstance(self.weights_model, torch.Tensor):
            model_data = {"_type": "tensor", "data": self.weights_model.tolist()}
        else:
            model_data = {
                "_type": "logistic",
                **serialize_logistic(self.weights_model),
            }
        return {
            "method": self.method,
            "alpha": self.alpha,
            "weights_model": model_data,
        }

    def load_state_dict(self, state: dict) -> None:
        """Load importance weighting state."""
        from .._sklearn_io import deserialize_logistic
        from ..exceptions import SerializationError

        try:
            self.method = state["method"]
            self.alpha = state["alpha"]
            payload = state["weights_model"]
            if payload is None:
                self.weights_model = None
            elif payload.get("_type") == "tensor":
                self.weights_model = torch.tensor(
                    payload["data"], dtype=torch.float32
                )
            else:
                self.weights_model = deserialize_logistic(payload)
        except Exception as e:
            raise SerializationError(f"Failed to load state: {e}") from e

    def save(self, path: str) -> None:
        """Save importance weighting state."""
        from ..exceptions import SerializationError

        try:
            torch.save(self.state_dict(), path)
        except Exception as e:
            raise SerializationError(f"Failed to save to {path}: {e}") from e

    @classmethod
    def load(
        cls, path: str, method: str = "logistic", alpha: float = 0.01
    ) -> "ImportanceWeightingShift":
        """Load importance weighting from file."""
        from ..exceptions import SerializationError

        try:
            instance = cls(method, alpha)
            state = torch.load(path, weights_only=True)
            instance.load_state_dict(state)
            return instance
        except Exception as e:
            raise SerializationError(f"Failed to load from {path}: {e}") from e

    def __repr__(self) -> str:
        fitted = self.weights_model is not None
        return (
            f"ImportanceWeightingShift(method='{self.method}', "
            f"alpha={self.alpha}, fitted={fitted})"
        )