incerto.shift.BaseShiftDetector

class incerto.shift.BaseShiftDetector

Bases: ABC

Abstract base class for distribution shift detection methods.

Shift detectors compare two distributions (reference and test) to determine whether they differ significantly. This is useful for detecting dataset drift, covariate shift, or concept drift in deployed models.

All detectors follow a fit-score pattern:

1. fit(): Store the reference (source) distribution.
2. score(): Compute the shift metric against the test (target) distribution.

Subclasses must implement _compute(), which defines the specific test statistic.

Example

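>>> import torch
>>> from incerto.shift import BaseShiftDetector
>>> class MyShiftDetector(BaseShiftDetector):
...     def _compute(self, test: torch.Tensor) -> float:
...         # Compute your shift metric; some_distance is a placeholder
...         # for any distance between two sample tensors.
...         return some_distance(self._reference, test)
...
>>> detector = MyShiftDetector()
>>> # reference_loader and test_loader are DataLoaders you provide
>>> detector.fit(reference_loader)
>>> shift_score = detector.score(test_loader)
>>> # threshold is user-chosen; its scale depends on the test statistic
>>> if shift_score > threshold:
...     print("Significant shift detected!")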
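The fit-score pattern described above can be illustrated without the library. The following is a minimal sketch under stated assumptions, not incerto's implementation: SketchShiftDetector and MeanGapDetector are hypothetical names, plain lists of floats stand in for DataLoader batches and tensors, and the mean-difference statistic is a toy chosen only to keep the example short.

```python
from abc import ABC, abstractmethod
from typing import Iterable, List, Sequence


class SketchShiftDetector(ABC):
    """Hypothetical stand-in for the fit-score pattern (not incerto's code)."""

    def fit(self, reference_batches: Iterable[Sequence[float]]) -> "SketchShiftDetector":
        # Concatenate all batches into one cached reference sample.
        self._reference: List[float] = [x for batch in reference_batches for x in batch]
        return self  # returning self enables method chaining

    def score(self, test_batches: Iterable[Sequence[float]]) -> float:
        # Concatenate the test batches, then delegate to the subclass statistic.
        test = [x for batch in test_batches for x in batch]
        return float(self._compute(test))

    @abstractmethod
    def _compute(self, test: List[float]) -> float:
        """Return the shift statistic between self._reference and test."""


class MeanGapDetector(SketchShiftDetector):
    """Toy statistic: absolute difference between the two sample means."""

    def _compute(self, test: List[float]) -> float:
        ref_mean = sum(self._reference) / len(self._reference)
        test_mean = sum(test) / len(test)
        return abs(ref_mean - test_mean)


reference_batches = [[0.0, 1.0], [2.0, 3.0]]  # reference mean is 1.5
test_batches = [[4.0, 5.0], [6.0, 7.0]]       # test mean is 5.5
gap = MeanGapDetector().fit(reference_batches).score(test_batches)
```

In this sketch, calling score() before fit() fails with an AttributeError because _reference has not been set yet, which mirrors the behavior documented for score().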
_reference

Cached reference distribution samples, available after calling fit().

See also

  • MMDShiftDetector: Maximum Mean Discrepancy-based detection

  • EnergyDistanceDetector: Energy distance-based detection

  • KSTest: Kolmogorov-Smirnov test for 1D features

__init__()

Methods

__init__()

fit(reference_loader)

Fit the detector on reference (source) distribution.

load(path)

Load a detector from a file.

load_state_dict(state)

Load detector state from a dictionary.

save(path)

Save detector state to a file.

score(test_loader)

Compute shift score between reference and test distributions.

state_dict()

Return a dictionary containing the detector's state.

fit(reference_loader)

Fit the detector on reference (source) distribution.

This method caches the reference data for comparison with test data. The reference distribution is typically your training data or a known good distribution.

Parameters:

reference_loader (DataLoader) – DataLoader for the reference distribution. All batches are concatenated into a single tensor.

Returns:

The detector itself (self), which allows method chaining.

Return type:

self

Example

>>> detector = MyShiftDetector()
>>> detector.fit(train_loader).score(test_loader)
score(test_loader)

Compute shift score between reference and test distributions.

Higher scores indicate more significant distributional shift.

Parameters:

test_loader (DataLoader) – DataLoader for the test (target) distribution.

Return type:

float

Returns:

Scalar shift score (higher = more shift detected). The scale and interpretation depend on the specific test statistic.

Raises:

AttributeError – If fit() has not been called first.

Example

>>> shift_score = detector.score(deployment_loader)
>>> print(f"Shift score: {shift_score:.4f}")
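Because the scale of the score depends on the test statistic, a fixed threshold rarely transfers between detectors. One common way to pick a threshold is a permutation test: pool the reference and test samples, reshuffle them many times, and take a high quantile of the resulting scores. The sketch below illustrates that idea only. It is not part of incerto, and permutation_threshold, mean_gap, and all parameter names are hypothetical; mean_gap stands in for a detector's fit-then-score call.

```python
import random
from typing import Callable, List, Sequence


def permutation_threshold(
    reference: Sequence[float],
    test: Sequence[float],
    statistic: Callable[[Sequence[float], Sequence[float]], float],
    n_permutations: int = 200,
    quantile: float = 0.95,
    seed: int = 0,
) -> float:
    """Estimate a threshold as a quantile of the statistic under random relabeling."""
    rng = random.Random(seed)
    pooled = list(reference) + list(test)
    n_ref = len(reference)
    null_scores: List[float] = []
    for _ in range(n_permutations):
        # Shuffling the pooled sample simulates "no shift" between the two groups.
        rng.shuffle(pooled)
        null_scores.append(statistic(pooled[:n_ref], pooled[n_ref:]))
    null_scores.sort()
    index = min(int(quantile * n_permutations), n_permutations - 1)
    return null_scores[index]


def mean_gap(a: Sequence[float], b: Sequence[float]) -> float:
    # Toy stand-in for detector.fit(a).score(b): absolute difference of means.
    return abs(sum(a) / len(a) - sum(b) / len(b))


reference = [float(i % 10) for i in range(100)]      # values 0..9, mean 4.5
shifted = [float(i % 10) + 5.0 for i in range(100)]  # same shape, mean 9.5

threshold = permutation_threshold(reference, shifted, mean_gap)
observed = mean_gap(reference, shifted)
shift_detected = observed > threshold
```

The quantile controls the false-alarm rate under the no-shift assumption; a higher quantile gives a stricter threshold.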
state_dict()

Return a dictionary containing the detector’s state.

Return type:

dict

Returns:

Dictionary containing reference data and detector parameters.

load_state_dict(state)

Load detector state from a dictionary.

Parameters:

state (dict) – Dictionary containing detector state.

Raises:

SerializationError – If state is invalid.

Return type:

None
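The round trip between state_dict() and load_state_dict() can be sketched as follows. This is an illustration under assumptions, not the library's code: the real dictionary keys are not documented here, so the "reference" and "alpha" keys, the SketchDetector class, and the SketchSerializationError exception are all hypothetical stand-ins.

```python
from typing import Any, Dict, List


class SketchSerializationError(Exception):
    """Hypothetical stand-in for the library's SerializationError."""


class SketchDetector:
    """Hypothetical detector holding a cached reference sample and one parameter."""

    def __init__(self, alpha: float = 0.05) -> None:
        self.alpha = alpha
        self._reference: List[float] = []

    def state_dict(self) -> Dict[str, Any]:
        # Copy the data so later changes to the detector do not alter the snapshot.
        return {"reference": list(self._reference), "alpha": self.alpha}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        # Reject states that lack the keys this sketch expects.
        missing = {"reference", "alpha"} - set(state)
        if missing:
            raise SketchSerializationError(f"missing keys: {sorted(missing)}")
        self._reference = list(state["reference"])
        self.alpha = state["alpha"]


source = SketchDetector(alpha=0.01)
source._reference = [0.1, 0.2, 0.3]  # pretend fit() has already run

restored = SketchDetector()
restored.load_state_dict(source.state_dict())
```

Validating the dictionary before assigning anything keeps the detector unchanged when the state is invalid.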

save(path)

Save detector state to a file.

Parameters:

path (str) – File path where the state will be saved.

Raises:

SerializationError – If saving fails.

Return type:

None

Example

>>> detector.fit(reference_loader)
>>> detector.save('detector_state.pt')
classmethod load(path)

Load a detector from a file.

Parameters:

path (str) – File path to load the state from.

Return type:

BaseShiftDetector

Returns:

A new detector instance with loaded state.

Raises:

SerializationError – If loading fails.

Example

>>> detector = MyShiftDetector.load('detector_state.pt')
>>> shift_score = detector.score(test_loader)