incerto.shift.LabelShiftDetector

class incerto.shift.LabelShiftDetector(num_classes, calibrated=False)

Bases: object

Black-Box Shift Detection for label shift.

Detects and quantifies label shift (prior probability shift) using confusion matrix estimation.

Reference:

Lipton et al., “Detecting and Correcting for Label Shift with Black Box Predictors” (ICML 2018)
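The confusion-matrix estimation in Lipton et al. rests on a linear identity, written out below as a short derivation. The notation (p, q, C, μ, w) is chosen here for illustration and is not taken from incerto's source. The derivation assumes the label-shift condition P_t(x | y) = P_s(x | y) and an invertible confusion matrix C. The first equality for μ holds because the prediction ŷ depends only on x, so its class-conditional behaviour carries over from source to target.

```latex
\begin{aligned}
C_{ij} &= P_s(\hat{y} = i,\, y = j), \qquad p_j = P_s(y = j), \qquad q_j = P_t(y = j), \qquad w_j = q_j / p_j, \\
\mu_i &= P_t(\hat{y} = i) = \sum_j P_s(\hat{y} = i \mid y = j)\, q_j = \sum_j \frac{C_{ij}}{p_j}\, q_j = \sum_j C_{ij}\, w_j, \\
\Rightarrow\; w &= C^{-1} \mu, \qquad q_j = w_j\, p_j .
\end{aligned}
```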

Parameters:
  • num_classes (int) – Number of classes

  • calibrated (bool) – Whether predictions are calibrated

__init__(num_classes, calibrated=False)
Parameters:
  • num_classes (int)

  • calibrated (bool)

Methods

__init__(num_classes[, calibrated])

compute_shift_magnitude(model, target_loader)

Compute magnitude of label shift.

estimate_target_distribution(model, ...)

Estimate target label distribution.

fit(model, source_loader, validation_loader)

Fit label shift detector.

load(path[, num_classes, calibrated])

Load label shift detector from file.

load_state_dict(state)

Load label shift detector state.

save(path)

Save label shift detector state.

state_dict()

Return the label shift detector state.

fit(model, source_loader, validation_loader)

Fit label shift detector.

Parameters:
  • model (Module) – Trained classifier

  • source_loader (DataLoader) – Source domain data with labels

  • validation_loader (DataLoader) – Validation set from source domain
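A minimal NumPy sketch of the statistics a fit step of this kind plausibly needs: a joint confusion matrix estimated on held-out source data, plus the source label prior. `fit_bbse_statistics` is a hypothetical helper, not part of incerto. It takes hard predictions (for example, the argmax of the model's outputs on `validation_loader`) and integer labels as plain arrays. Which loader incerto actually uses for each quantity is an assumption here.

```python
import numpy as np


def fit_bbse_statistics(val_preds, val_labels, source_labels, num_classes):
    """Hypothetical helper: estimate the source-side quantities BBSE needs.

    Returns (confusion, source_prior), where confusion[i, j] estimates the
    joint probability P_s(y_hat = i, y = j) on held-out source data and
    source_prior[j] estimates P_s(y = j).
    """
    val_preds = np.asarray(val_preds)
    val_labels = np.asarray(val_labels)
    confusion = np.zeros((num_classes, num_classes))
    # Accumulate counts: row = predicted class, column = true class.
    np.add.at(confusion, (val_preds, val_labels), 1.0)
    # Normalize counts into a joint distribution (entries sum to 1).
    confusion /= len(val_labels)
    # Empirical source label distribution (assumed to come from source labels).
    source_labels = np.asarray(source_labels)
    source_prior = np.bincount(source_labels, minlength=num_classes) / len(source_labels)
    return confusion, source_prior


C, p = fit_bbse_statistics([0, 0, 1, 1], [0, 1, 1, 1], [0, 1, 1, 1], num_classes=2)
```

Using the joint matrix (rather than the row-normalized conditional one) means each column of C sums to the corresponding entry of the source prior, which is what the linear system in the estimation step relies on.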

estimate_target_distribution(model, target_loader)

Estimate target label distribution.

Parameters:
  • model (Module) – Trained classifier

  • target_loader (DataLoader) – Target domain data (no labels needed)

Return type:

Tensor

Returns:

Estimated target label distribution
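Given a joint confusion matrix and source prior, the estimation step from Lipton et al. can be sketched as solving a linear system for importance weights and reweighting the source prior. `estimate_target_prior` is hypothetical, not incerto's implementation. It assumes hard target predictions, and clipping negative weights and renormalizing are practical choices assumed here; incerto may handle finite-sample noise differently.

```python
import numpy as np


def estimate_target_prior(confusion, source_prior, target_preds):
    """Hypothetical BBSE step: estimate the target label distribution q.

    confusion[i, j] ~ P_s(y_hat = i, y = j), source_prior[j] ~ P_s(y = j),
    target_preds holds the classifier's hard predictions on unlabeled target data.
    """
    confusion = np.asarray(confusion, dtype=float)
    source_prior = np.asarray(source_prior, dtype=float)
    k = confusion.shape[0]
    # mu[i] ~ P_t(y_hat = i): predicted-class frequencies on the target set.
    target_preds = np.asarray(target_preds)
    mu = np.bincount(target_preds, minlength=k) / len(target_preds)
    # Solve C w = mu for the importance weights w[j] = q[j] / p[j].
    w = np.linalg.solve(confusion, mu)
    # Finite-sample noise can produce negative weights; clip them (assumption).
    w = np.clip(w, 0.0, None)
    q = w * source_prior
    total = q.sum()
    if total <= 0:
        raise ValueError("all estimated weights are zero; cannot normalize")
    return q / total


q = estimate_target_prior([[0.25, 0.25], [0.0, 0.5]], [0.25, 0.75], [0, 0, 1])
```

Here the source prior is (0.25, 0.75), and two of the three target predictions are class 0; solving the system recovers weights (2, 2/3) and a balanced target distribution (0.5, 0.5).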

compute_shift_magnitude(model, target_loader, metric='tvd')

Compute magnitude of label shift.

Parameters:
  • model (Module) – Trained classifier

  • target_loader (DataLoader) – Target domain data

  • metric (str) – Metric to use: ‘tvd’ (total variation distance), ‘kl’ (Kullback-Leibler divergence), or ‘l2’ (Euclidean distance)

Return type:

float

Returns:

Shift magnitude
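The three metric names correspond to standard distances between discrete distributions. The pure-Python sketch below, `label_shift_magnitude`, is hypothetical. Which two distributions incerto compares (presumably the source prior and the estimated target distribution), the argument order used for the asymmetric KL divergence, and the `eps` smoothing are all assumptions.

```python
import math


def label_shift_magnitude(p, q, metric="tvd", eps=1e-12):
    """Hypothetical distance between two discrete label distributions p and q."""
    if len(p) != len(q):
        raise ValueError("p and q must have the same length")
    if metric == "tvd":
        # Total variation distance: half the L1 distance, bounded in [0, 1].
        return 0.5 * sum(abs(a - b) for a, b in zip(p, q))
    if metric == "kl":
        # KL(p || q); terms with p_i = 0 contribute nothing, eps avoids log(0).
        return sum(a * math.log((a + eps) / (b + eps)) for a, b in zip(p, q) if a > 0)
    if metric == "l2":
        # Euclidean distance between the probability vectors.
        return math.sqrt(sum((a - b) ** 2 for a, b in zip(p, q)))
    raise ValueError(f"unknown metric: {metric!r}")


source = [0.25, 0.75]
target = [0.5, 0.5]
tvd = label_shift_magnitude(source, target, "tvd")
```

TVD is symmetric and bounded, which makes it a convenient default; KL is unbounded and asymmetric, so its value depends on which distribution is passed first.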

state_dict()

Return the label shift detector state as a dictionary.

Return type:

dict

load_state_dict(state)

Load label shift detector state.

Parameters:

state (dict)

Return type:

None

save(path)

Save label shift detector state.

Parameters:

path (str)

Return type:

None

classmethod load(path, num_classes=0, calibrated=False)

Load label shift detector from file.

Parameters:
  • path (str) – File path to load the state from.

  • num_classes (int) – Ignored; restored from saved state. Kept for backward compatibility.

  • calibrated (bool) – Ignored; restored from saved state. Kept for backward compatibility.

Return type:

LabelShiftDetector