incerto.shift.MMDShiftDetector

class incerto.shift.MMDShiftDetector(sigma=1.0)

Bases: BaseShiftDetector

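Shift detector based on the kernel Maximum Mean Discrepancy (MMD) with a Gaussian (RBF) kernel.

Computes the biased MMD estimator, which includes the diagonal kernel terms k(x_i, x_i) and k(y_j, y_j). The bias shrinks as the sample sizes grow, so for large samples the estimate converges to the true (population) MMD.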
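As a rough illustration of what the biased estimator computes, here is a NumPy sketch of the squared MMD with an RBF kernel. It assumes the kernel convention k(x, y) = exp(-||x - y||^2 / (2 sigma^2)); whether incerto scales the exponent this way, and whether its score is MMD or MMD^2, is not stated on this page. The helper names `rbf_kernel` and `mmd2_biased` are hypothetical.

```python
import numpy as np


def rbf_kernel(a: np.ndarray, b: np.ndarray, sigma: float) -> np.ndarray:
    """Kernel matrix K[i, j] = exp(-||a_i - b_j||^2 / (2 * sigma^2)).

    The 2 * sigma^2 scaling is an assumed convention, not taken from incerto.
    """
    # Pairwise squared Euclidean distances, shape (len(a), len(b))
    sq_dists = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-sq_dists / (2.0 * sigma**2))


def mmd2_biased(x: np.ndarray, y: np.ndarray, sigma: float = 1.0) -> float:
    """Biased (V-statistic) estimate of squared MMD between x (n, d) and y (m, d).

    Taking plain means over the full kernel matrices keeps the diagonal
    terms k(x_i, x_i) and k(y_j, y_j); the unbiased estimator excludes them.
    """
    k_xx = rbf_kernel(x, x, sigma)
    k_yy = rbf_kernel(y, y, sigma)
    k_xy = rbf_kernel(x, y, sigma)
    return float(k_xx.mean() + k_yy.mean() - 2.0 * k_xy.mean())


rng = np.random.default_rng(0)
reference = rng.normal(size=(200, 2))
same = rng.normal(size=(200, 2))              # same distribution as reference
shifted = rng.normal(loc=1.0, size=(200, 2))  # mean moved by 1 in each dimension

print(mmd2_biased(reference, same))     # small positive value
print(mmd2_biased(reference, shifted))  # clearly larger
```

Dropping the diagonal terms (averaging the within-sample off-diagonal entries over n(n-1) and m(m-1) pairs) gives the unbiased estimator from the same paper, which can come out slightly negative when there is no shift.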

Reference:

Gretton et al., “A Kernel Two-Sample Test” (JMLR 2012)

Parameters:

sigma (float) – RBF kernel bandwidth parameter (default: 1.0)
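The bandwidth controls how local the comparison is. The sketch below (same assumed kernel convention, hypothetical `mmd2_biased` helper) shows the two degenerate limits of the biased estimator: for a very small sigma every off-diagonal kernel entry vanishes, so the estimate collapses to 1/n + 1/m, the contribution of the diagonal terms alone, regardless of the data; for a very large sigma every entry approaches 1 and the estimate goes to 0. Informative values lie in between, on the scale of typical distances between feature vectors.

```python
import numpy as np


def mmd2_biased(x: np.ndarray, y: np.ndarray, sigma: float) -> float:
    """Biased MMD^2 with an assumed RBF kernel exp(-||u - v||^2 / (2 sigma^2))."""

    def k(a: np.ndarray, b: np.ndarray) -> np.ndarray:
        # Pairwise squared distances; the diagonal of k(x, x) is exactly 0 here
        sq_dists = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
        return np.exp(-sq_dists / (2.0 * sigma**2))

    return float(k(x, x).mean() + k(y, y).mean() - 2.0 * k(x, y).mean())


rng = np.random.default_rng(1)
x = rng.normal(size=(50, 2))           # n = 50 reference points
y = rng.normal(loc=3.0, size=(40, 2))  # m = 40 strongly shifted points

for sigma in (1e-6, 1.0, 1e6):
    print(f"sigma={sigma:g}: MMD^2 = {mmd2_biased(x, y, sigma):.6f}")
```

With sigma = 1e-6 the value should be close to 1/50 + 1/40 = 0.045 even though y is shifted by 3 in each coordinate, and with sigma = 1e6 it should be close to 0; only the middle setting actually separates the two samples.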

__init__(sigma=1.0)
Parameters:

sigma (float)

Methods

__init__([sigma])

fit(reference_loader)

Fit the detector on reference (source) distribution.

load(path)

Load a detector from a file.

load_state_dict(state)

Restore MMD detector state from a dictionary.

save(path)

Save detector state to a file.

score(test_loader)

Compute shift score between reference and test distributions.

state_dict()

Return the MMD detector state as a dictionary.
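The fit/score split can be pictured with the following self-contained toy. `ToyMMDDetector` is hypothetical and is not incerto's implementation: it assumes each loader simply yields 2-D NumPy feature batches, whereas the batch format that incerto's `fit` and `score` expect is not documented on this page. `fit` stores the reference sample; `score` compares a new sample against it with the biased estimator and an assumed RBF convention.

```python
from typing import Iterable, Optional

import numpy as np


class ToyMMDDetector:
    """Hypothetical stand-in mirroring the fit/score workflow (not incerto's code)."""

    def __init__(self, sigma: float = 1.0) -> None:
        self.sigma = sigma
        self.reference: Optional[np.ndarray] = None

    @staticmethod
    def _collect(loader: Iterable[np.ndarray]) -> np.ndarray:
        # Concatenate all (batch, d) arrays yielded by the loader into one (N, d) array
        return np.concatenate([np.asarray(b, dtype=float) for b in loader], axis=0)

    def _kernel(self, a: np.ndarray, b: np.ndarray) -> np.ndarray:
        # Assumed RBF convention: exp(-||a - b||^2 / (2 sigma^2))
        sq_dists = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
        return np.exp(-sq_dists / (2.0 * self.sigma**2))

    def fit(self, reference_loader: Iterable[np.ndarray]) -> "ToyMMDDetector":
        # MMD is nonparametric: "fitting" just means keeping the reference sample
        self.reference = self._collect(reference_loader)
        return self

    def score(self, test_loader: Iterable[np.ndarray]) -> float:
        if self.reference is None:
            raise RuntimeError("call fit() before score()")
        x, y = self.reference, self._collect(test_loader)
        # Biased MMD^2: plain means keep the diagonal kernel terms
        return float(
            self._kernel(x, x).mean()
            + self._kernel(y, y).mean()
            - 2.0 * self._kernel(x, y).mean()
        )


rng = np.random.default_rng(0)
# A "loader" here is just a list of (batch_size, d) arrays
reference_loader = [rng.normal(size=(32, 3)) for _ in range(4)]
clean_loader = [rng.normal(size=(32, 3)) for _ in range(4)]
shifted_loader = [rng.normal(loc=1.0, size=(32, 3)) for _ in range(4)]

detector = ToyMMDDetector(sigma=1.0).fit(reference_loader)
print("clean:  ", detector.score(clean_loader))
print("shifted:", detector.score(shifted_loader))
```

A higher score means the test batches look less like the reference data; turning the score into a yes/no shift decision needs a threshold, for example from a permutation test, which this toy does not include.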

state_dict()

Return the MMD detector state as a dictionary.

Return type:

dict

load_state_dict(state)

Restore MMD detector state from a dictionary.

Parameters:

state (dict)

Return type:

None
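The estimator described above needs the reference sample at scoring time, so a saved MMD detector has to carry that sample along with the bandwidth. The sketch below assumes a state dictionary with hypothetical keys `sigma` and `reference` and round-trips it through a file with `pickle`; the actual keys returned by `state_dict()` and the file format used by `save`/`load` are not documented here.

```python
import os
import pickle
import tempfile

import numpy as np


def make_state(sigma: float, reference: np.ndarray) -> dict:
    # Hypothetical layout: the bandwidth plus the stored reference sample
    return {"sigma": float(sigma), "reference": np.asarray(reference, dtype=float)}


def save_state(state: dict, path: str) -> None:
    with open(path, "wb") as f:
        pickle.dump(state, f)


def load_state(path: str) -> dict:
    # Only unpickle files you trust: pickle can execute arbitrary code on load
    with open(path, "rb") as f:
        return pickle.load(f)


rng = np.random.default_rng(0)
state = make_state(sigma=0.5, reference=rng.normal(size=(100, 4)))

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "mmd_state.pkl")
    save_state(state, path)
    restored = load_state(path)

print(restored["sigma"], restored["reference"].shape)
```

Because the whole reference sample is stored, the saved state grows linearly with the number of reference points.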