incerto.shift.BaseShiftDetector
- class incerto.shift.BaseShiftDetector
Bases: ABC
Abstract base class for distribution shift detection methods.
Shift detectors compare two distributions (reference and test) to determine if they differ significantly. This is crucial for detecting dataset drift, covariate shift, or concept drift in deployed models.
All detectors follow a fit-score pattern:
1. fit(): store the reference (source) distribution.
2. score(): compute a shift metric against the test (target) distribution.
Subclasses must implement _compute(), which defines the specific test statistic.
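The fit-score contract can be sketched without PyTorch as below. This is a hypothetical, dependency-free illustration, not incerto's implementation: SketchShiftDetector and MeanShiftDetector are invented names, plain lists of floats stand in for tensors and DataLoaders, and the mean-difference statistic is a toy chosen only to make _compute() concrete.

```python
from abc import ABC, abstractmethod
from typing import Iterable, List, Sequence


class SketchShiftDetector(ABC):
    """Plain-Python stand-in for the fit/score pattern; not incerto's implementation."""

    def fit(self, reference_loader: Iterable[Sequence[float]]) -> "SketchShiftDetector":
        # Flatten every batch into one cached list of reference samples.
        self._reference: List[float] = [x for batch in reference_loader for x in batch]
        return self  # returning self enables detector.fit(...).score(...)

    def score(self, test_loader: Iterable[Sequence[float]]) -> float:
        # Mirror the documented behaviour: scoring before fit() is an error.
        if not hasattr(self, "_reference"):
            raise AttributeError("fit() must be called before score()")
        test = [x for batch in test_loader for x in batch]
        return self._compute(test)

    @abstractmethod
    def _compute(self, test: List[float]) -> float:
        """Return the shift statistic between self._reference and test."""


class MeanShiftDetector(SketchShiftDetector):
    def _compute(self, test: List[float]) -> float:
        # Toy statistic: absolute difference between the two sample means.
        ref_mean = sum(self._reference) / len(self._reference)
        return abs(sum(test) / len(test) - ref_mean)


reference_batches = [[0.0, 1.0], [2.0, 3.0]]  # reference mean 1.5
test_batches = [[4.0, 5.0], [6.0]]  # test mean 5.0
print(MeanShiftDetector().fit(reference_batches).score(test_batches))  # 3.5
```

Because _compute() is abstract, the base class itself cannot be instantiated; only subclasses that supply a statistic can.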
Example
>>> import torch
>>> from incerto.shift import BaseShiftDetector
>>> class MyShiftDetector(BaseShiftDetector):
...     def _compute(self, test: torch.Tensor) -> float:
...         # Compute your shift metric; here, the distance between feature means
...         return (self._reference.mean(dim=0) - test.mean(dim=0)).norm().item()
...
>>> detector = MyShiftDetector()
>>> detector.fit(reference_loader)
>>> shift_score = detector.score(test_loader)
>>> if shift_score > threshold:
...     print("Significant shift detected!")
- _reference
Cached reference distribution samples after calling fit()
See also
MMDShiftDetector: Maximum Mean Discrepancy-based detection
EnergyDistanceDetector: Energy distance-based detection
KSTest: Kolmogorov-Smirnov test for 1D features
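For intuition about one of the statistics listed above, the two-sample Kolmogorov-Smirnov statistic on 1D features is the largest absolute gap between the two empirical CDFs. The function below is a standalone sketch of that textbook statistic; ks_statistic is a hypothetical name, it is not incerto's KSTest, and it computes no p-value.

```python
from bisect import bisect_right
from typing import Sequence


def ks_statistic(reference: Sequence[float], test: Sequence[float]) -> float:
    """Two-sample KS statistic: max over x of |F_ref(x) - F_test(x)|."""
    ref = sorted(reference)
    tst = sorted(test)
    d = 0.0
    # Both empirical CDFs are step functions that only jump at sample points,
    # so the supremum is attained at one of the pooled observations.
    for x in ref + tst:
        f_ref = bisect_right(ref, x) / len(ref)  # fraction of reference <= x
        f_tst = bisect_right(tst, x) / len(tst)  # fraction of test <= x
        d = max(d, abs(f_ref - f_tst))
    return d


print(ks_statistic([1.0, 2.0, 3.0, 4.0], [3.0, 4.0, 5.0, 6.0]))  # 0.5
```

The statistic ranges from 0.0 (identical empirical distributions) to 1.0 (no overlap at all).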
- __init__()
Methods
__init__()
fit(reference_loader) – Fit the detector on reference (source) distribution.
load(path) – Load a detector from a file.
load_state_dict(state) – Load detector state from a dictionary.
save(path) – Save detector state to a file.
score(test_loader) – Compute shift score between reference and test distributions.
state_dict() – Return a dictionary containing the detector's state.
- fit(reference_loader)
Fit the detector on reference (source) distribution.
This method caches the reference data for comparison with test data. The reference distribution is typically your training data or a known good distribution.
- Parameters:
reference_loader (DataLoader) – DataLoader for the reference distribution. All batches are concatenated into a single tensor.
- Returns:
The detector itself, for method chaining.
- Return type:
self
Example
>>> detector = MyShiftDetector()
>>> detector.fit(train_loader).score(test_loader)
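The parameter description says all batches are concatenated into a single tensor. For batches of feature vectors that means stacking rows along the sample axis, which in PyTorch would be torch.cat(batches, dim=0). The sketch below shows the same idea with nested Python lists standing in for tensors; concat_batches is a hypothetical helper, not part of incerto, and the width check is an illustrative addition.

```python
from typing import Iterable, List

Row = List[float]


def concat_batches(loader: Iterable[List[Row]]) -> List[Row]:
    """Stack (batch_size, n_features) batches into one (n_samples, n_features) list."""
    samples: List[Row] = []
    for batch in loader:
        for row in batch:
            # Every sample must have the same feature width, as in a 2D tensor.
            if samples and len(row) != len(samples[0]):
                raise ValueError("all samples must have the same number of features")
            samples.append(list(row))
    return samples


# Two full batches of size 2 and a final partial batch of size 1.
batches = [[[0.0, 1.0], [2.0, 3.0]], [[4.0, 5.0], [6.0, 7.0]], [[8.0, 9.0]]]
reference = concat_batches(batches)
print(len(reference))  # 5
```

A partial final batch (as produced by a DataLoader with drop_last=False) is simply appended, so every reference sample is kept.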
- score(test_loader)
Compute shift score between reference and test distributions.
Higher scores indicate more significant distributional shift.
- Parameters:
test_loader (DataLoader) – DataLoader for the test (target) distribution.
- Returns:
Scalar shift score (higher = more shift detected). The scale and interpretation depend on the specific test statistic.
- Return type:
float
- Raises:
AttributeError – If fit() has not been called first.
Example
>>> shift_score = detector.score(deployment_loader)
>>> print(f"Shift detected: {shift_score:.4f}")
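Since the scale of the score depends on the statistic, a raw threshold means different things for different detectors. The documentation does not say how to pick one; a common, statistic-agnostic option (an assumption here, not a documented incerto feature) is a permutation test: shuffle the pooled samples, recompute the statistic, and count how often the shuffled value reaches the observed one. The names permutation_p_value and mean_gap are hypothetical.

```python
import random
from typing import Callable, List, Sequence

Statistic = Callable[[Sequence[float], Sequence[float]], float]


def mean_gap(a: Sequence[float], b: Sequence[float]) -> float:
    """Toy shift statistic: absolute difference between sample means."""
    return abs(sum(a) / len(a) - sum(b) / len(b))


def permutation_p_value(
    statistic: Statistic,
    reference: Sequence[float],
    test: Sequence[float],
    n_permutations: int = 1000,
    seed: int = 0,
) -> float:
    """Estimate how often a random relabelling is at least as extreme as observed."""
    rng = random.Random(seed)
    observed = statistic(reference, test)
    pooled: List[float] = list(reference) + list(test)
    n_ref = len(reference)
    hits = 0
    for _ in range(n_permutations):
        rng.shuffle(pooled)  # break any association between sample and origin
        if statistic(pooled[:n_ref], pooled[n_ref:]) >= observed:
            hits += 1
    # Counting the observed split as one extra permutation keeps p strictly positive.
    return (hits + 1) / (n_permutations + 1)


same = [float(i) for i in range(10)]
shifted = [float(i) + 100.0 for i in range(10)]
print(permutation_p_value(mean_gap, same, same))  # 1.0
p = permutation_p_value(mean_gap, same, shifted)
print("shift detected" if p < 0.05 else "no significant shift")
```

Comparing p to a significance level such as 0.05 gives a decision rule whose meaning does not change from one detector to the next.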
- state_dict()
Return a dictionary containing the detector’s state.
- Returns:
Dictionary containing reference data and detector parameters.
- Return type:
dict
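state_dict() is described as holding the reference data and detector parameters. The sketch below shows one plausible shape for that dictionary together with its inverse, load_state_dict(). The class name StatefulSketch, the bandwidth hyperparameter, and the keys "reference" and "bandwidth" are all hypothetical and need not match incerto's actual state layout.

```python
from typing import Any, Dict, Iterable, List, Sequence


class StatefulSketch:
    """Hypothetical detector illustrating a state_dict/load_state_dict round trip."""

    def __init__(self, bandwidth: float = 1.0) -> None:
        self.bandwidth = bandwidth  # hypothetical hyperparameter

    def fit(self, reference_loader: Iterable[Sequence[float]]) -> "StatefulSketch":
        self._reference: List[float] = [x for batch in reference_loader for x in batch]
        return self

    def state_dict(self) -> Dict[str, Any]:
        # Copy the cached samples so the snapshot is independent of the detector.
        reference = getattr(self, "_reference", None)
        return {
            "reference": list(reference) if reference is not None else None,
            "bandwidth": self.bandwidth,
        }

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        if state["reference"] is not None:
            self._reference = list(state["reference"])
        self.bandwidth = state["bandwidth"]


original = StatefulSketch(bandwidth=0.5).fit([[1.0, 2.0], [3.0]])
restored = StatefulSketch()
restored.load_state_dict(original.state_dict())
print(restored._reference, restored.bandwidth)  # [1.0, 2.0, 3.0] 0.5
```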
- save(path)
Save detector state to a file.
- Parameters:
path (str) – File path where the state will be saved.
- Raises:
SerializationError – If saving fails.
- Return type:
None
Example
>>> detector.fit(reference_loader)
>>> detector.save('detector_state.pt')
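The '.pt' extension in the example suggests PyTorch serialization, but how incerto writes the file is not stated. The sketch below uses the standard-library pickle module so it runs without torch, and shows the documented behaviour of surfacing failures as SerializationError. Both save_state and the local SerializationError class are hypothetical stand-ins.

```python
import os
import pickle
import tempfile
from typing import Any, Dict


class SerializationError(Exception):
    """Local stand-in for the documented exception; incerto's class may differ."""


def save_state(state: Dict[str, Any], path: str) -> None:
    """Write a detector state dict to path, wrapping failures in SerializationError."""
    try:
        with open(path, "wb") as f:
            pickle.dump(state, f)
    except (OSError, pickle.PicklingError) as exc:
        raise SerializationError(f"could not save detector state to {path!r}") from exc


with tempfile.TemporaryDirectory() as tmp:
    target = os.path.join(tmp, "detector_state.pkl")
    save_state({"reference": [0.5, 1.5]}, target)
    print(os.path.exists(target))  # True
```

Chaining with `from exc` keeps the underlying I/O or pickling error available as `__cause__` for debugging.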
- classmethod load(path)
Load a detector from a file.
- Parameters:
path (str) – File path to load the state from.
- Returns:
A new detector instance with loaded state.
- Return type:
BaseShiftDetector
- Raises:
SerializationError – If loading fails.
Example
>>> detector = MyShiftDetector.load('detector_state.pt')
>>> shift_score = detector.score(test_loader)
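Because load() is a classmethod, calling it on a subclass such as MyShiftDetector can build an instance of that subclass rather than of the base class. The sketch below illustrates that pattern with pickle and a local SerializationError; LoadableSketch, MySketch, and the "reference" key are hypothetical, and incerto's actual file format is not stated in this page.

```python
import os
import pickle
import tempfile
from typing import Any, Dict, Type, TypeVar

T = TypeVar("T", bound="LoadableSketch")


class SerializationError(Exception):
    """Local stand-in for the documented exception; incerto's class may differ."""


class LoadableSketch:
    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self._reference = state["reference"]

    @classmethod
    def load(cls: Type[T], path: str) -> T:
        # pickle.load can execute arbitrary code: only load files you trust.
        try:
            with open(path, "rb") as f:
                state = pickle.load(f)
        except (OSError, EOFError, pickle.UnpicklingError) as exc:
            raise SerializationError(f"could not load detector state from {path!r}") from exc
        detector = cls()  # instance of whichever subclass load() was called on
        detector.load_state_dict(state)
        return detector


class MySketch(LoadableSketch):
    pass


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "detector_state.pkl")
    with open(path, "wb") as f:
        pickle.dump({"reference": [1.0, 2.0]}, f)
    loaded = MySketch.load(path)
    print(type(loaded).__name__, loaded._reference)  # MySketch [1.0, 2.0]
```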