incerto.calibration.BaseCalibrator
- class incerto.calibration.BaseCalibrator
Bases: ABC
Abstract base class for all calibration methods.
Calibrators are post-hoc methods that adjust a trained model’s predicted probabilities to better match empirical frequencies. All calibrators follow a fit-predict pattern:
fit(): Learn calibration parameters on a validation set
predict(): Apply calibration to new logits
Subclasses must implement both methods, as well as the abstract state_dict() and load_state_dict() methods used for serialization.
Example
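>>> import torch.nn.functional as F
>>> from torch.distributions import Categorical
>>> from incerto.calibration import BaseCalibrator
>>> class MyCalibrator(BaseCalibrator):
...     def fit(self, logits, labels):
...         # Learn calibration parameters (find_optimal_temperature is a placeholder)
...         self.temperature = find_optimal_temperature(logits, labels)
...         return self
...
...     def predict(self, logits):
...         # Apply calibration
...         calibrated_logits = logits / self.temperature
...         probs = F.softmax(calibrated_logits, dim=-1)
...         return Categorical(probs=probs)
...
...     # state_dict() and load_state_dict() are abstract too, so a concrete
...     # subclass needs them before it can be instantiated.
...     def state_dict(self):
...         return {'temperature': self.temperature}
...
...     def load_state_dict(self, state):
...         self.temperature = state['temperature']
...
>>> calibrator = MyCalibrator()
>>> calibrator.fit(val_logits, val_labels)
>>> calibrated_dist = calibrator.predict(test_logits)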
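The example above calls find_optimal_temperature without defining it. One way such a helper could look is sketched below, assuming a simple grid search that minimizes the negative log-likelihood on the validation set. The function names, the grid bounds, and the use of plain Python lists instead of tensors are all assumptions made for illustration; this is not the library's implementation.

```python
import math


def softmax(row, temperature=1.0):
    """Numerically stable softmax of one row of logits divided by a temperature."""
    scaled = [x / temperature for x in row]
    largest = max(scaled)
    exps = [math.exp(s - largest) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def mean_nll(logits, labels, temperature):
    """Average negative log-likelihood of the true labels at a given temperature."""
    losses = [-math.log(softmax(row, temperature)[y]) for row, y in zip(logits, labels)]
    return sum(losses) / len(losses)


def find_optimal_temperature(logits, labels, low=0.05, high=10.0, steps=200):
    """Hypothetical helper: grid-search the temperature that minimizes NLL."""
    candidates = [low + i * (high - low) / (steps - 1) for i in range(steps)]
    return min(candidates, key=lambda t: mean_nll(logits, labels, t))


# Overconfident toy logits: large margins, but one of four predictions is wrong.
val_logits = [[8.0, 0.0], [0.0, 8.0], [8.0, 0.0], [8.0, 0.0]]
val_labels = [0, 1, 0, 1]
temperature = find_optimal_temperature(val_logits, val_labels)
```

Because the toy model is more confident than its accuracy justifies, the selected temperature comes out above 1, which softens the predicted probabilities. A gradient-based optimizer would be a common alternative to the grid search used here.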
See also
TemperatureScaling: Simple and effective temperature-based calibration
VectorScaling: Class-wise temperature scaling
IsotonicRegression: Non-parametric calibration
- __init__()
Methods
__init__()
fit(logits, labels): Fit the calibrator on a validation set.
load(path): Load a calibrator from a file.
load_state_dict(state): Load calibrator state from a dictionary.
predict(logits): Apply calibration to logits and return a Categorical distribution.
save(path): Save calibrator state to a file.
state_dict(): Return a dictionary containing the calibrator's state.
- abstract fit(logits, labels)
Fit the calibrator on a validation set.
This method should learn any calibration parameters needed to map uncalibrated logits to calibrated probabilities.
- Parameters:
- Returns:
self, to allow method chaining.
Note
The validation set used here should be separate from both training and test sets to avoid overfitting the calibration parameters.
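The separation described in the note can be sketched as a three-way index split. The helper name three_way_split, the 10% fractions, and the fixed seed are illustrative assumptions and not part of the library; the point is only that the calibration set shares no samples with the training or test sets.

```python
import random


def three_way_split(n_samples, val_fraction=0.1, test_fraction=0.1, seed=0):
    """Hypothetical helper: shuffle indices and cut them into disjoint subsets."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    n_val = int(n_samples * val_fraction)
    n_test = int(n_samples * test_fraction)
    val_idx = indices[:n_val]                    # used only for calibrator.fit()
    test_idx = indices[n_val:n_val + n_test]     # used only for final evaluation
    train_idx = indices[n_val + n_test:]         # used only to train the model
    return train_idx, val_idx, test_idx


train_idx, val_idx, test_idx = three_way_split(1000)
```

The logits passed to fit() would then be the model's outputs on the samples indexed by val_idx.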
- abstract predict(logits)
Apply calibration to logits and return a Categorical distribution.
This method applies the learned calibration to transform uncalibrated logits into calibrated probabilities.
- Parameters:
logits (Tensor) – Uncalibrated model outputs of shape (n_samples, n_classes).
- Return type:
Categorical
- Returns:
A torch.distributions.Categorical distribution over calibrated probabilities. Access probabilities via .probs or sample via .sample().
Example
>>> calibrated_dist = calibrator.predict(test_logits)
>>> calibrated_probs = calibrated_dist.probs  # shape: (n_samples, n_classes)
>>> predictions = calibrated_dist.sample()  # shape: (n_samples,)
>>> log_probs = calibrated_dist.log_prob(labels)  # shape: (n_samples,)
- abstract state_dict()
Return a dictionary containing the calibrator’s state.
This should include all learned parameters needed to reproduce the calibrator’s behavior after fitting.
- Return type:
dict
- Returns:
Dictionary mapping parameter names to their values.
Example
>>> calibrator.fit(val_logits, val_labels)
>>> state = calibrator.state_dict()
>>> # state might be {'temperature': 1.5, 'n_classes': 10}
- abstract load_state_dict(state)
Load calibrator state from a dictionary.
This restores the calibrator to a previously saved state, allowing it to be used without refitting.
- Parameters:
state (dict) – Dictionary containing calibrator state, typically from a previous call to state_dict().
- Raises:
SerializationError – If state is invalid or incompatible.
- Return type:
Example
>>> state = calibrator.state_dict()
>>> new_calibrator = MyCalibrator()
>>> new_calibrator.load_state_dict(state)
>>> # new_calibrator now behaves identically to calibrator
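A subclass might implement the state_dict() and load_state_dict() pair roughly as follows. This is a minimal sketch under stated assumptions: TemperatureState is a hypothetical stand-alone class rather than a real BaseCalibrator subclass, and SerializationError is redefined locally as a stand-in for the library's exception so the snippet runs on its own.

```python
class SerializationError(Exception):
    """Local stand-in for the library's exception, defined so the sketch runs alone."""


class TemperatureState:
    """Hypothetical calibrator-like class holding a single learned parameter."""

    def __init__(self):
        self.temperature = 1.0

    def state_dict(self):
        # Include every learned parameter needed to reproduce predict().
        return {"temperature": self.temperature}

    def load_state_dict(self, state):
        # Reject input that cannot have come from state_dict().
        if not isinstance(state, dict) or "temperature" not in state:
            raise SerializationError("state must be a dict with a 'temperature' key")
        temperature = float(state["temperature"])
        if temperature <= 0:
            raise SerializationError("temperature must be positive")
        self.temperature = temperature


fitted = TemperatureState()
fitted.temperature = 1.5  # pretend this value came from fit()

restored = TemperatureState()
restored.load_state_dict(fitted.state_dict())
```

Validating the dictionary before assigning anything keeps a failed load from leaving the object in a half-restored state.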
- save(path)
Save calibrator state to a file.
- Parameters:
path (str) – File path where the state will be saved.
- Raises:
SerializationError – If saving fails.
- Return type:
Example
>>> calibrator.fit(val_logits, val_labels)
>>> calibrator.save('calibrator.pt')
- classmethod load(path)
Load a calibrator from a file.
- Parameters:
path (str) – File path to load the state from.
- Return type:
- Returns:
A new calibrator instance with loaded state.
- Raises:
SerializationError – If loading fails.
Example
>>> calibrator = MyCalibrator.load('calibrator.pt')
>>> calibrated_dist = calibrator.predict(test_logits)
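Since save() and load() are concrete while state_dict() and load_state_dict() are abstract, the file methods are presumably built on top of the dictionary methods. The sketch below illustrates that layering only. The on-disk format used by the library is not stated here, so JSON is swapped in as an assumption, and JsonCalibrator is a hypothetical stand-alone class.

```python
import json
import os
import tempfile


class JsonCalibrator:
    """Hypothetical class showing save/load layered on state_dict/load_state_dict."""

    def __init__(self, temperature=1.0):
        self.temperature = temperature

    def state_dict(self):
        return {"temperature": self.temperature}

    def load_state_dict(self, state):
        self.temperature = float(state["temperature"])

    def save(self, path):
        # Assumption: JSON on disk; the real library may use a different format.
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.state_dict(), f)

    @classmethod
    def load(cls, path):
        with open(path, encoding="utf-8") as f:
            state = json.load(f)
        instance = cls()  # build a fresh instance, then restore its state
        instance.load_state_dict(state)
        return instance


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "calibrator.json")
    JsonCalibrator(temperature=1.5).save(path)
    loaded = JsonCalibrator.load(path)
```

Making load a classmethod means each subclass reconstructs an instance of its own type, which matches the MyCalibrator.load(...) call in the example above.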