Distribution Shift Detection Guide#
Distribution shift detection identifies when deployment data differs from the data a model was trained on. It is critical for maintaining model performance in production.
Why Shift Detection Matters#
Model performance degrades when data distribution changes:
- Covariate shift: P(X) changes, P(Y|X) stays the same
Example: The camera angle changes but the objects remain the same
- Label shift: P(Y) changes, P(X|Y) stays the same
Example: Disease prevalence changes seasonally
- Concept drift: P(Y|X) changes
Example: User preferences change over time
Detecting shifts enables proactive retraining and monitoring.
Two-Sample Tests#
Two-sample tests check whether two datasets come from the same distribution.
Maximum Mean Discrepancy (MMD)#
Best for: Non-parametric, works in any dimension
MMD is a kernel-based test that measures the distance between two distributions:
from incerto.shift import MMDShiftDetector
detector = MMDShiftDetector(sigma=1.0)
# Fit on reference (training) data
detector.fit(reference_loader)
# Compute shift score
shift_score = detector.score(deployment_loader)
print(f"MMD score: {shift_score:.4f}")
# Higher score = more shift detected
# Save for monitoring
detector.save('mmd_detector.pt')
- Advantages:
Non-parametric
Theoretically sound
Works well empirically
- Disadvantages:
Requires choosing kernel bandwidth (sigma)
O(n²) complexity (can be approximated)
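For intuition, here is a minimal NumPy sketch of the biased MMD² estimate with an RBF kernel (illustrative, not the incerto implementation). A common default for sigma is the median pairwise distance (the "median heuristic"):
import numpy as np
def rbf_kernel(a, b, sigma):
    # Gaussian (RBF) kernel matrix between the rows of a and b
    d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-d2 / (2 * sigma ** 2))
def mmd2_biased(x, y, sigma=1.0):
    # Biased MMD^2 estimate: E[k(x,x')] + E[k(y,y')] - 2 E[k(x,y)]
    return (rbf_kernel(x, x, sigma).mean()
            + rbf_kernel(y, y, sigma).mean()
            - 2 * rbf_kernel(x, y, sigma).mean())
rng = np.random.default_rng(0)
ref = rng.normal(0.0, 1.0, size=(200, 5))
same = rng.normal(0.0, 1.0, size=(200, 5))      # same distribution
shifted = rng.normal(0.5, 1.0, size=(200, 5))   # mean-shifted
print(f"no shift: {mmd2_biased(ref, same):.4f}")     # near zero
print(f"shift:    {mmd2_biased(ref, shifted):.4f}")  # clearly larger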
Energy Distance#
Best for: Alternative to MMD, similar properties
from incerto.shift import EnergyShiftDetector
detector = EnergyShiftDetector()
detector.fit(reference_loader)
shift_score = detector.score(deployment_loader)
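Energy distance is built from pairwise Euclidean distances: E(X, Y) = 2·E‖X − Y‖ − E‖X − X′‖ − E‖Y − Y′‖. A minimal NumPy sketch of the plug-in estimate (illustrative, not the incerto internals):
import numpy as np
def pairwise_dists(a, b):
    # Euclidean distances between all rows of a and all rows of b
    return np.sqrt(((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1))
def energy_distance(x, y):
    # Plug-in estimate of 2 E||X - Y|| - E||X - X'|| - E||Y - Y'||
    return (2 * pairwise_dists(x, y).mean()
            - pairwise_dists(x, x).mean()
            - pairwise_dists(y, y).mean())
Like MMD, the statistic is near zero when the samples match and grows with the shift; it needs no bandwidth parameter, which is its main practical appeal.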
Kolmogorov-Smirnov Test#
Best for: 1D features, interpretable
Per-feature univariate test:
from incerto.shift import KSShiftDetector
detector = KSShiftDetector()
detector.fit(reference_loader)
# Returns maximum KS statistic across features
shift_score = detector.score(deployment_loader)
- Advantages:
Simple, interpretable
Exact p-values available
- Disadvantages:
Only tests marginal distributions
Misses multivariate dependencies
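A detector of this kind essentially runs scipy.stats.ks_2samp column by column; a minimal sketch (function and array names illustrative, arrays assumed to be samples × features):
import numpy as np
from scipy.stats import ks_2samp
def max_ks_statistic(reference, test):
    # Two-sample KS test per feature; keep the worst (largest) statistic
    stats = [ks_2samp(reference[:, j], test[:, j]).statistic
             for j in range(reference.shape[1])]
    return max(stats), int(np.argmax(stats))
score, worst_feature = max_ks_statistic(reference_features, test_features)
print(f"Max KS statistic {score:.3f} (feature {worst_feature})")
When testing many features, apply a multiple-testing correction (e.g. Bonferroni) before reading per-feature p-values.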
Classifier Two-Sample Test#
Best for: High dimensions, model-agnostic
Train a classifier to distinguish the reference data from the test data:
from incerto.shift import ClassifierShiftDetector
detector = ClassifierShiftDetector()
detector.fit(reference_loader)
# Fit classifier and compute shift score
shift_score = detector.score(deployment_loader)
# Score is how well classifier can distinguish datasets
# 0.5 = no shift (random guessing)
# 1.0 = complete shift (perfect separation)
Intuition: if the distributions differ, a classifier can learn to tell them apart.
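A minimal sketch of this test with scikit-learn (function name illustrative): label reference samples 0 and test samples 1, fit any classifier, and use its cross-validated ROC AUC as the shift score:
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
def domain_classifier_auc(reference, test):
    # AUC ~ 0.5: indistinguishable (no shift); AUC -> 1.0: strong shift
    X = np.vstack([reference, test])
    y = np.concatenate([np.zeros(len(reference)), np.ones(len(test))])
    clf = LogisticRegression(max_iter=1000)
    return cross_val_score(clf, X, y, cv=5, scoring="roc_auc").mean()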
Label Shift Detection#
Detect and quantify label distribution changes:
from incerto.shift import LabelShiftDetector
detector = LabelShiftDetector(num_classes=10)
# Fit on source data
detector.fit(model, source_loader, validation_loader)
# Estimate target label distribution
target_dist = detector.estimate_target_distribution(
    model,
    target_loader
)
print(f"Source dist: {detector.source_label_dist}")
print(f"Target dist: {target_dist}")
# Quantify shift magnitude
shift_mag = detector.compute_shift_magnitude(
    model,
    target_loader,
    metric='tvd'  # Total variation distance
)
print(f"Label shift magnitude: {shift_mag:.4f}")
Importance Weighting#
Adapt to covariate shift via importance weighting:
from incerto.shift import ImportanceWeightingShift
# Estimate importance weights
iw = ImportanceWeightingShift(method='logistic', alpha=0.01)
iw.fit(source_features, target_features)
# Get weights for source samples
weights = iw.compute_weights(source_features)
# Use in training
for x, y in source_loader:
loss = model(x, y)
weighted_loss = iw.weighted_loss(loss, weights)
weighted_loss.backward()
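The 'logistic' method typically estimates the density ratio with a domain classifier: by Bayes' rule, p_target(x) / p_source(x) ∝ g(x) / (1 − g(x)), where g(x) = P(target | x). A minimal scikit-learn sketch (function name illustrative):
import numpy as np
from sklearn.linear_model import LogisticRegression
def logistic_importance_weights(source_X, target_X):
    # Train a classifier to predict "is this a target sample?"
    X = np.vstack([source_X, target_X])
    y = np.concatenate([np.zeros(len(source_X)), np.ones(len(target_X))])
    clf = LogisticRegression(max_iter=1000).fit(X, y)
    g = np.clip(clf.predict_proba(source_X)[:, 1], 1e-6, 1 - 1e-6)
    w = (g / (1 - g)) * (len(source_X) / len(target_X))  # fix class balance
    return w / w.mean()  # normalize so weights average to 1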
Complete Monitoring Workflow#
from incerto.shift import MMDShiftDetector
import time
# 1. Train model and save reference data
model = train_model(train_loader)
detector = MMDShiftDetector(sigma=1.0)
detector.fit(train_loader) # Reference distribution
# Save baseline
detector.save('shift_detector.pt')
# 2. Production monitoring loop
threshold = ...  # calibrate on held-out reference data (see sketch below)
shift_scores = []
timestamps = []
while True:
    # Get recent production data
    prod_data = get_recent_production_data()
    # Detect shift
    shift_score = detector.score(prod_data)
    shift_scores.append(shift_score)
    timestamps.append(time.time())
    # Alert if shift detected
    if shift_score > threshold:
        print("⚠️ Distribution shift detected!")
        print(f"Shift score: {shift_score:.4f}")
        # Trigger retraining, alert team, etc.
    # Wait before next check
    time.sleep(3600)  # Check hourly
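One way to set the alert threshold, sketched below assuming detector.score accepts each held-out batch (helper and variable names are illustrative): score batches drawn from the training distribution to form a null distribution of "no shift" scores, then alert above a high quantile of it:
import numpy as np
def calibrate_threshold(detector, reference_batches, quantile=0.99):
    # Scores on held-out reference data approximate the no-shift null;
    # a high quantile keeps the false-alarm rate near 1 - quantile
    null_scores = [detector.score(batch) for batch in reference_batches]
    return float(np.quantile(null_scores, quantile))
threshold = calibrate_threshold(detector, held_out_reference_batches)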
Evaluation Metrics#
- Shift score:
Method-specific score quantifying shift magnitude
- Statistical tests:
P-value for hypothesis test (no shift vs. shift)
from incerto.shift import mmd_test
result = mmd_test(reference_data, test_data)
print(f"Test statistic: {result['statistic']:.4f}")
print(f"P-value: {result['p_value']:.4f}")
if result['p_value'] < 0.05:
    print("Significant shift detected (p < 0.05)")
- Area Under Shift Curve:
For gradual shifts, plot shift score over time
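Reusing shift_scores, timestamps, and threshold from the monitoring loop above, a small matplotlib sketch of the time-series view:
import matplotlib.pyplot as plt
plt.plot(timestamps, shift_scores, marker="o")
plt.axhline(threshold, color="red", linestyle="--", label="alert threshold")
plt.xlabel("time")
plt.ylabel("shift score")
plt.legend()
plt.savefig("shift_over_time.png")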
Best Practices#
- Establish baseline on clean training data
Reference distribution should be well-defined
- Monitor continuously
Shifts can be gradual or sudden
- Use multiple detectors
Different methods catch different shift types (see the sketch after this list)
- Set meaningful thresholds
Based on validation data, not test data
- Combine with performance monitoring
Shift detection + accuracy drop = retrain
- Save detector state
Allows consistent monitoring over time
- Consider feature-level monitoring
Track individual feature distributions
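An illustrative sketch of running multiple detectors side by side, using the incerto detectors shown above (the thresholds here are placeholders that should be calibrated as described earlier):
from incerto.shift import MMDShiftDetector, KSShiftDetector
detectors = {
    'mmd': (MMDShiftDetector(sigma=1.0), 0.05),  # placeholder threshold
    'ks': (KSShiftDetector(), 0.20),             # placeholder threshold
}
for det, _ in detectors.values():
    det.fit(reference_loader)
alerts = {name: det.score(deployment_loader) > thr
          for name, (det, thr) in detectors.items()}
if any(alerts.values()):
    print(f"Shift flagged by: {[n for n, a in alerts.items() if a]}")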
Common Shift Scenarios#
- Natural drift:
Gradual changes over time (user preferences, seasonality)
- Sudden shifts:
Abrupt changes (new data source, hardware change)
- Adversarial shifts:
Intentional distribution changes (attacks)
- Data quality issues:
Preprocessing bugs, sensor failures
References#
Gretton et al., “A Kernel Two-Sample Test” (JMLR 2012)
Lipton et al., “Detecting and Correcting for Label Shift” (ICML 2018)
Rabanser et al., “Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift” (NeurIPS 2019)
Quiñonero-Candela et al., “Dataset Shift in Machine Learning” (MIT Press 2009)
See Also#
Distribution Shift Detection - Complete API reference
Out-of-Distribution Detection Guide - Flagging individual out-of-distribution inputs
Calibration Guide - Model calibration under shift