incerto.shift.LabelShiftDetector
- class incerto.shift.LabelShiftDetector(num_classes, calibrated=False)
Bases: object
Black-Box Shift Detection for label shift.
Detects and quantifies label shift (prior probability shift) using confusion matrix estimation.
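The identity behind the method, as stated by Lipton et al. (2018): under label shift, p(x | y) is the same in both domains, so the classifier's confusion behaviour P(ŷ | y) carries over from source to target. The target distribution of predictions is therefore a linear function of the label-weight vector. Let $C_{ij} = P_{\text{source}}(\hat{y} = i, y = j)$, $w_j = q(y = j) / p(y = j)$ and $\mu_i = P_{\text{target}}(\hat{y} = i)$. Then the following holds, which requires $\hat{C}$ to be invertible:

$$
\mu_i = \sum_{j} C_{ij} \, w_j
\quad\Longrightarrow\quad
\hat{w} = \hat{C}^{-1} \hat{\mu},
\qquad
\hat{q}(y = j) = \hat{w}_j \, \hat{p}(y = j)
$$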
- Reference:
Lipton et al., “Detecting and Correcting for Label Shift with Black Box Predictors” (ICML 2018)
- Parameters:
  - num_classes – Number of classes.
  - calibrated – Defaults to False.
Methods
- __init__(num_classes[, calibrated])
- compute_shift_magnitude(model, target_loader) – Compute magnitude of label shift.
- estimate_target_distribution(model, ...) – Estimate target label distribution.
- fit(model, source_loader, validation_loader) – Fit label shift detector.
- load(path[, num_classes, calibrated]) – Load label shift detector from file.
- load_state_dict(state) – Load label shift detector state.
- save(path) – Save label shift detector state.
- fit(model, source_loader, validation_loader)
Fit label shift detector.
- Parameters:
  - model (Module) – Trained classifier.
  - source_loader (DataLoader) – Source domain data with labels.
  - validation_loader (DataLoader) – Validation set from the source domain.
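Fitting needs two statistics from labelled source data: the joint confusion matrix C[i, j] = P(ŷ = i, y = j) and the source label marginal p(y). The following is a minimal NumPy sketch, not incerto's implementation. It assumes the classifier's hard predictions and true labels have already been collected into integer arrays instead of being read from DataLoaders. The helper names `estimate_confusion_matrix` and `estimate_label_marginal` are hypothetical. Which loader feeds which statistic is also an assumption: a plausible split is validation_loader for C and source_loader for p(y).

```python
import numpy as np


def estimate_confusion_matrix(y_pred: np.ndarray, y_true: np.ndarray, num_classes: int) -> np.ndarray:
    """Hypothetical helper: joint confusion matrix C[i, j] = P_source(y_hat = i, y = j)."""
    C = np.zeros((num_classes, num_classes), dtype=float)
    # Accumulate one count per sample: row = predicted class, column = true class.
    np.add.at(C, (y_pred, y_true), 1.0)
    # Normalise counts into a joint probability table.
    return C / len(y_true)


def estimate_label_marginal(y: np.ndarray, num_classes: int) -> np.ndarray:
    """Hypothetical helper: empirical label distribution p(y) from integer labels."""
    counts = np.bincount(y, minlength=num_classes).astype(float)
    return counts / counts.sum()


# Toy held-out source set with 3 classes (illustrative data only).
y_true = np.array([0, 0, 1, 1, 2, 2, 2, 0])
y_pred = np.array([0, 1, 1, 1, 2, 2, 0, 0])
C = estimate_confusion_matrix(y_pred, y_true, num_classes=3)
p_source = estimate_label_marginal(y_true, num_classes=3)
```

When C and p(y) come from the same labelled set, column j of C sums to p_source[j]. This is a cheap sanity check on the estimate.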
- estimate_target_distribution(model, target_loader)
Estimate target label distribution.
- Parameters:
  - model (Module) – Trained classifier.
  - target_loader (DataLoader) – Target domain data (no labels needed).
- Return type:
- Returns:
Estimated target label distribution
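The estimate comes from averaging the classifier's hard predictions on unlabelled target data to get μ, then solving C w = μ for the weights and rescaling by p(y). The sketch below is a minimal NumPy version under those assumptions, not the library's code. It takes C and p_source as produced at fit time, and the function name is hypothetical. Two choices here are common practical safeguards and assumptions about what incerto does: clipping negative weights to zero, and solving with least squares rather than an explicit inverse.

```python
import numpy as np


def estimate_target_distribution(C: np.ndarray, p_source: np.ndarray, target_preds: np.ndarray) -> np.ndarray:
    """Hypothetical BBSE sketch: estimate q(y) from hard target predictions."""
    k = C.shape[0]
    # mu[i] = fraction of target samples predicted as class i.
    mu = np.bincount(target_preds, minlength=k).astype(float) / len(target_preds)
    # Solve C w = mu for the importance weights w[j] = q(j) / p(j).
    # lstsq avoids raising if C is singular or ill-conditioned.
    w, *_ = np.linalg.lstsq(C, mu, rcond=None)
    # Importance weights cannot be negative; clip estimation noise.
    w = np.clip(w, 0.0, None)
    q = w * p_source
    # Renormalise so the estimate is a valid distribution (assumes q.sum() > 0).
    return q / q.sum()


# Binary example: balanced source, classifier that is 80% accurate on each class.
C = np.array([[0.4, 0.1],
              [0.1, 0.4]])
p_source = np.array([0.5, 0.5])
# Target predictions: 68 predicted as class 0, 32 as class 1.
target_preds = np.array([0] * 68 + [1] * 32)
q_hat = estimate_target_distribution(C, p_source, target_preds)
```

In this example, the raw predicted-label frequencies on the target are 0.68 and 0.32. After correcting for the classifier's confusion matrix, the estimate becomes 0.8 and 0.2, because μ = C w with w = (1.6, 0.4).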
- compute_shift_magnitude(model, target_loader, metric='tvd')
Compute magnitude of label shift.
- Parameters:
  - model (Module) – Trained classifier.
  - target_loader (DataLoader) – Target domain data.
  - metric (str) – Metric to use: 'tvd', 'kl', or 'l2'. Defaults to 'tvd'.
- Return type:
- Returns:
Shift magnitude
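The three metric names most likely refer to standard distances between the source label distribution p and the estimated target distribution q. This is an assumption, since the page does not define them. The sketch below uses the textbook definitions: total variation distance ½‖p − q‖₁, KL divergence, and Euclidean distance. The function name `shift_magnitude` and the `eps` smoothing term are hypothetical. The KL direction, KL(q ‖ p), is a guess and may differ from incerto's choice.

```python
import numpy as np


def shift_magnitude(p_source: np.ndarray, q_target: np.ndarray, metric: str = "tvd", eps: float = 1e-12) -> float:
    """Hypothetical sketch: distance between source and target label distributions."""
    p = np.asarray(p_source, dtype=float)
    q = np.asarray(q_target, dtype=float)
    if metric == "tvd":
        # Total variation distance: half the L1 distance, always in [0, 1].
        return float(0.5 * np.abs(p - q).sum())
    if metric == "kl":
        # KL(q || p); direction is an assumption. eps guards against log(0) and division by zero.
        return float(np.sum(q * np.log((q + eps) / (p + eps))))
    if metric == "l2":
        # Euclidean distance between the two probability vectors.
        return float(np.linalg.norm(p - q))
    raise ValueError(f"unknown metric: {metric!r}")


p = np.array([0.5, 0.5])
q = np.array([0.8, 0.2])
tvd = shift_magnitude(p, q, "tvd")
kl = shift_magnitude(p, q, "kl")
l2 = shift_magnitude(p, q, "l2")
```

TVD is bounded and symmetric, which makes it easy to threshold. KL is asymmetric and unbounded, so it reacts more strongly when a class that is rare in the source becomes common in the target.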