incerto.shift.MMDShiftDetector
- class incerto.shift.MMDShiftDetector(sigma=1.0)
Bases: BaseShiftDetector
Kernel Maximum Mean Discrepancy (MMD) with a Gaussian (RBF) kernel.
Computes the biased MMD estimator, which includes the diagonal (i = j) terms of the within-sample kernel sums. Its bias shrinks as the sample sizes grow, so for large samples it converges to the population MMD.
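The biased estimator described above can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not incerto's implementation: the kernel convention exp(-||x - y||² / (2σ²)) is an assumption (libraries differ on whether σ enters as 2σ² or σ²), the helper names `rbf_kernel` and `mmd2_biased` are hypothetical, and the sketch returns the squared-MMD V-statistic, whereas this page does not say whether `score` reports MMD² or its square root.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def rbf_kernel(x: Vector, y: Vector, sigma: float = 1.0) -> float:
    """Gaussian (RBF) kernel exp(-||x - y||^2 / (2 * sigma^2)).

    Assumption: this bandwidth convention may differ from incerto's.
    """
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq_dist / (2.0 * sigma ** 2))


def mmd2_biased(xs: List[Vector], ys: List[Vector], sigma: float = 1.0) -> float:
    """Biased (V-statistic) estimate of squared MMD (hypothetical helper).

    MMD^2_b = mean_{i,j} k(x_i, x_j) + mean_{i,j} k(y_i, y_j)
              - 2 * mean_{i,j} k(x_i, y_j)
    Every average runs over all pairs, including the diagonal i == j,
    which is what makes this estimator biased.
    """
    m, n = len(xs), len(ys)
    k_xx = sum(rbf_kernel(a, b, sigma) for a in xs for b in xs) / (m * m)
    k_yy = sum(rbf_kernel(a, b, sigma) for a in ys for b in ys) / (n * n)
    k_xy = sum(rbf_kernel(a, b, sigma) for a in xs for b in ys) / (m * n)
    return k_xx + k_yy - 2.0 * k_xy
```

Because all three averages include the i = j terms, the result equals the squared distance between the two empirical mean embeddings and is therefore never negative (up to rounding), unlike the unbiased estimator, which can dip below zero on finite samples.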
- Reference:
Gretton et al., “A Kernel Two-Sample Test” (JMLR 2012)
- Parameters:
sigma (float) – RBF kernel bandwidth parameter (default: 1.0)
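The value of `sigma` matters: if it is much smaller than typical distances between samples, all off-diagonal kernel values approach 0; if it is much larger, they approach 1. In both cases the score becomes insensitive to shift. One common heuristic, which this class is not documented to apply, sets the bandwidth to the median pairwise distance of the reference data. A sketch, with the hypothetical helper name `median_heuristic_sigma`:

```python
import math
import statistics
from typing import List, Sequence


def median_heuristic_sigma(samples: List[Sequence[float]]) -> float:
    """Median pairwise Euclidean distance (i < j) as an RBF bandwidth.

    Hypothetical helper, not part of incerto. Other variants of the
    heuristic (e.g. based on squared distances) also exist.
    """
    dists = [
        math.dist(samples[i], samples[j])
        for i in range(len(samples))
        for j in range(i + 1, len(samples))
    ]
    if not dists:
        raise ValueError("need at least two samples")
    sigma = statistics.median(dists)
    # Fall back to the documented default (1.0) when all points coincide.
    return sigma if sigma > 0 else 1.0
```

The result could then be passed to the constructor, e.g. `MMDShiftDetector(sigma=median_heuristic_sigma(reference_features))`, where `reference_features` stands for feature vectors you have extracted yourself.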
Methods
__init__([sigma])
fit(reference_loader) – Fit the detector on the reference (source) distribution.
load(path) – Load a detector from a file.
load_state_dict(state) – Load MMD detector state.
save(path) – Save detector state to a file.
score(test_loader) – Compute the shift score between the reference and test distributions.
Save MMD detector state.
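The methods listed above suggest a fit-then-score workflow with serializable state. The toy class below mirrors that workflow in plain Python so it runs without incerto; it is hypothetical throughout. Loaders are modeled as iterables of batches of feature vectors (the real `reference_loader`/`test_loader` types are not specified on this page), state is serialized to JSON, and the getter name `state_dict` is an assumed counterpart to `load_state_dict`.

```python
import json
import math
from typing import Dict, Iterable, List, Sequence

Batch = Sequence[Sequence[float]]


class ToyMMDDetector:
    """Hypothetical, simplified stand-in for MMDShiftDetector's workflow."""

    def __init__(self, sigma: float = 1.0) -> None:
        self.sigma = sigma
        self.reference: List[List[float]] = []

    @staticmethod
    def _collect(loader: Iterable[Batch]) -> List[List[float]]:
        # Flatten an iterable of batches into one list of samples.
        return [list(x) for batch in loader for x in batch]

    def _k(self, x: Sequence[float], y: Sequence[float]) -> float:
        # Assumed RBF convention: exp(-||x - y||^2 / (2 * sigma^2)).
        sq = sum((a - b) ** 2 for a, b in zip(x, y))
        return math.exp(-sq / (2.0 * self.sigma ** 2))

    def fit(self, reference_loader: Iterable[Batch]) -> "ToyMMDDetector":
        # Store reference samples; scoring compares against them later.
        self.reference = self._collect(reference_loader)
        return self

    def score(self, test_loader: Iterable[Batch]) -> float:
        # Biased squared-MMD estimate between reference and test samples.
        xs, ys = self.reference, self._collect(test_loader)
        if not xs or not ys:
            raise ValueError("call fit() first and pass non-empty test data")
        m, n = len(xs), len(ys)
        k_xx = sum(self._k(a, b) for a in xs for b in xs) / (m * m)
        k_yy = sum(self._k(a, b) for a in ys for b in ys) / (n * n)
        k_xy = sum(self._k(a, b) for a in xs for b in ys) / (m * n)
        return k_xx + k_yy - 2.0 * k_xy

    def state_dict(self) -> Dict:
        # Assumed PyTorch-style getter name paired with load_state_dict.
        return {"sigma": self.sigma, "reference": self.reference}

    def load_state_dict(self, state: Dict) -> None:
        self.sigma = state["sigma"]
        self.reference = [list(x) for x in state["reference"]]

    def save(self, path: str) -> None:
        with open(path, "w") as f:
            json.dump(self.state_dict(), f)

    @classmethod
    def load(cls, path: str) -> "ToyMMDDetector":
        with open(path) as f:
            state = json.load(f)
        detector = cls()
        detector.load_state_dict(state)
        return detector
```

Storing the raw reference samples keeps the sketch simple, but the state then grows with the reference set; a real implementation may store something more compact.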