incerto.calibration.BaseCalibrator

class incerto.calibration.BaseCalibrator

Bases: ABC

Abstract base class for all calibration methods.

Calibrators are post-hoc methods that adjust a trained model’s predicted probabilities to better match empirical frequencies: ideally, among predictions made with 80% confidence, about 80% are correct. All calibrators follow a fit-predict pattern:

  1. fit(): Learn calibration parameters on a validation set

  2. predict(): Apply calibration to new logits

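Subclasses must implement all four abstract methods: fit(), predict(), state_dict(), and load_state_dict(). The save() and load() methods are inherited from the base class.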

Example

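>>> import torch.nn.functional as F
>>> from torch.distributions import Categorical
>>> from incerto.calibration import BaseCalibrator
>>> class MyCalibrator(BaseCalibrator):
...     def fit(self, logits, labels):
...         # Learn calibration parameters; find_optimal_temperature is a
...         # placeholder for your own optimization routine
...         self.temperature = find_optimal_temperature(logits, labels)
...         return self
...
...     def predict(self, logits):
...         # Apply calibration: rescale the logits, then normalize
...         calibrated_logits = logits / self.temperature
...         probs = F.softmax(calibrated_logits, dim=-1)
...         return Categorical(probs=probs)
...
...     # state_dict and load_state_dict are also abstract;
...     # without them, MyCalibrator() raises TypeError
...     def state_dict(self):
...         return {'temperature': self.temperature}
...
...     def load_state_dict(self, state):
...         self.temperature = state['temperature']
...
>>> calibrator = MyCalibrator().fit(val_logits, val_labels)
>>> calibrated_dist = calibrator.predict(test_logits)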
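The example above leaves find_optimal_temperature undefined. One way to write such a helper is sketched below, following the usual temperature-scaling recipe of minimizing the validation negative log-likelihood over a single scalar T. The function name, its signature, and the choice of LBFGS are assumptions for illustration, not part of incerto.

```python
import torch
import torch.nn.functional as F


def find_optimal_temperature(logits: torch.Tensor, labels: torch.Tensor,
                             max_iter: int = 100) -> float:
    """Hypothetical helper: return the temperature T > 0 that minimizes the
    negative log-likelihood of softmax(logits / T) on the validation labels."""
    logits = logits.detach()
    # Optimize log(T) rather than T so the temperature always stays positive.
    log_t = torch.zeros(1, dtype=logits.dtype, requires_grad=True)
    optimizer = torch.optim.LBFGS([log_t], max_iter=max_iter,
                                  line_search_fn="strong_wolfe")

    def closure():
        optimizer.zero_grad()
        # cross_entropy applies log-softmax internally, so this is the NLL.
        loss = F.cross_entropy(logits / log_t.exp(), labels)
        loss.backward()
        return loss

    optimizer.step(closure)
    return log_t.exp().item()
```

A learned T > 1 softens overconfident predictions, while T < 1 sharpens underconfident ones; dividing by a positive scalar never changes the argmax, so accuracy is unaffected.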

See also

  • TemperatureScaling: Simple and effective temperature-based calibration

  • VectorScaling: Class-wise temperature scaling

  • IsotonicRegression: Non-parametric calibration

__init__()

Methods

  • __init__()

  • fit(logits, labels) – Fit the calibrator on a validation set.

  • load(path) – Load a calibrator from a file.

  • load_state_dict(state) – Load calibrator state from a dictionary.

  • predict(logits) – Apply calibration to logits and return a Categorical distribution.

  • save(path) – Save calibrator state to a file.

  • state_dict() – Return a dictionary containing the calibrator’s state.

abstract fit(logits, labels)

Fit the calibrator on a validation set.

This method should learn any calibration parameters needed to map uncalibrated logits to calibrated probabilities.

Parameters:
  • logits (Tensor) – Uncalibrated model outputs of shape (n_samples, n_classes). These should be raw logits before softmax.

  • labels (Tensor) – True class labels of shape (n_samples,) with integer values in range [0, n_classes-1].

Returns:

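The fitted calibrator itself, so calls can be chained, for example calibrator.fit(val_logits, val_labels).predict(test_logits).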

Return type:

self
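The input contract in the parameter list above can be checked before any parameters are learned. A minimal sketch, assuming a hypothetical helper check_fit_inputs that a subclass might call at the top of its fit(); whether BaseCalibrator performs any of these checks itself is not stated on this page. The final probability check is only a heuristic.

```python
import warnings

import torch


def check_fit_inputs(logits: torch.Tensor, labels: torch.Tensor) -> None:
    """Hypothetical helper: enforce the shapes and label range documented for fit()."""
    if logits.dim() != 2 or logits.shape[0] == 0:
        raise ValueError(
            f"logits must have shape (n_samples, n_classes) with n_samples > 0, "
            f"got {tuple(logits.shape)}"
        )
    n_samples, n_classes = logits.shape
    if labels.shape != (n_samples,):
        raise ValueError(f"labels must have shape ({n_samples},), got {tuple(labels.shape)}")
    if labels.dtype.is_floating_point or labels.dtype == torch.bool:
        raise TypeError(f"labels must be integer class indices, got dtype {labels.dtype}")
    if labels.min() < 0 or labels.max() >= n_classes:
        raise ValueError(f"labels must lie in [0, {n_classes - 1}]")
    # Heuristic: non-negative rows that sum to 1 look like softmax outputs, not raw logits.
    row_sums = logits.sum(dim=-1)
    if bool((logits >= 0).all()) and torch.allclose(row_sums, torch.ones_like(row_sums)):
        warnings.warn("logits look like probabilities; fit() expects raw pre-softmax logits")
```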

Note

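The validation set used here should be disjoint from both the training set and the test set. Logits on training examples are typically more confident than on unseen data, so parameters fitted on them do not transfer, and fitting on the test set makes the reported calibration optimistically biased.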
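When only one held-out split (never used for training) is available, it can be partitioned once, at random, into a calibration part for fit() and a disjoint part for evaluation. A minimal sketch; the helper name split_calibration_test and its 50/50 default are assumptions for illustration.

```python
import torch


def split_calibration_test(logits, labels, cal_fraction=0.5, seed=0):
    """Hypothetical helper: randomly split held-out data into a calibration set
    (used by fit) and a disjoint test set (used only for evaluation)."""
    n = logits.shape[0]
    # A seeded generator makes the split reproducible across runs.
    generator = torch.Generator().manual_seed(seed)
    perm = torch.randperm(n, generator=generator)
    n_cal = int(n * cal_fraction)
    cal_idx, test_idx = perm[:n_cal], perm[n_cal:]
    return (logits[cal_idx], labels[cal_idx]), (logits[test_idx], labels[test_idx])
```

For example, (val_logits, val_labels), (test_logits, test_labels) = split_calibration_test(heldout_logits, heldout_labels), where heldout_logits and heldout_labels are hypothetical names for the model outputs and labels on that held-out split.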

abstract predict(logits)

Apply calibration to logits and return a Categorical distribution.

This method applies the learned calibration to transform uncalibrated logits into calibrated probabilities.

Parameters:

logits (Tensor) – Uncalibrated model outputs of shape (n_samples, n_classes).

Return type:

Categorical

Returns:

A torch.distributions.Categorical distribution over classes, parameterized by the calibrated probabilities. Access the probabilities via .probs or draw class samples via .sample().

Example

>>> calibrated_dist = calibrator.predict(test_logits)
>>> calibrated_probs = calibrated_dist.probs            # shape: (n_samples, n_classes)
>>> predictions = calibrated_probs.argmax(dim=-1)       # most likely class, shape: (n_samples,)
>>> samples = calibrated_dist.sample()                  # random class draws, shape: (n_samples,)
>>> log_probs = calibrated_dist.log_prob(test_labels)   # shape: (n_samples,)
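Whether calibration helped can be checked by comparing confidence against empirical accuracy on the test set. A minimal sketch of the standard top-label expected calibration error (ECE) with equal-width bins, computed from calibrated_dist.probs; the function name and the 15-bin default are illustrative choices, not incerto's API.

```python
import torch


def expected_calibration_error(probs: torch.Tensor, labels: torch.Tensor,
                               n_bins: int = 15) -> float:
    """Top-label ECE: bin predictions by confidence, then average the
    |confidence - accuracy| gap over bins, weighted by the fraction of samples per bin."""
    confidences, predictions = probs.max(dim=-1)
    correct = predictions.eq(labels).to(confidences.dtype)
    bin_edges = torch.linspace(0.0, 1.0, n_bins + 1).tolist()
    ece = 0.0
    for lower, upper in zip(bin_edges[:-1], bin_edges[1:]):
        # Bins are (lower, upper]; top-label confidences are at least 1/n_classes > 0.
        in_bin = (confidences > lower) & (confidences <= upper)
        if in_bin.any():
            weight = in_bin.float().mean().item()
            gap = (confidences[in_bin].mean() - correct[in_bin].mean()).abs().item()
            ece += weight * gap
    return ece
```

Comparing expected_calibration_error(calibrated_dist.probs, test_labels) with the same quantity on test_logits.softmax(dim=-1) shows whether the calibrator reduced the gap. ECE depends on the binning, so compare values only under the same n_bins.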

abstract state_dict()

Return a dictionary containing the calibrator’s state.

This should include all learned parameters needed to reproduce the calibrator’s behavior after fitting.

Return type:

dict

Returns:

Dictionary mapping parameter names to their values.

Example

>>> calibrator.fit(val_logits, val_labels)
>>> state = calibrator.state_dict()
>>> # state might be {'temperature': 1.5, 'n_classes': 10}

abstract load_state_dict(state)

Load calibrator state from a dictionary.

This restores the calibrator to a previously saved state, allowing it to be used without refitting.

Parameters:

state (dict) – Dictionary containing calibrator state, typically from a previous call to state_dict().

Raises:

SerializationError – If state is invalid or incompatible.

Return type:

None

Example

>>> state = calibrator.state_dict()
>>> new_calibrator = MyCalibrator()
>>> new_calibrator.load_state_dict(state)
>>> # new_calibrator now behaves identically to calibrator
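For a single-temperature calibrator, the state_dict()/load_state_dict() pair can be very small. The sketch below is an illustration under assumptions: the class name is hypothetical, and SerializationError is redefined locally as a stand-in because its import location in incerto is not given on this page. Validating before assigning means a rejected state leaves the object unchanged.

```python
class SerializationError(Exception):
    """Stand-in for incerto's SerializationError (import path not shown on this page)."""


class TemperatureState:
    """Hypothetical calibrator fragment showing a state_dict/load_state_dict pair."""

    def __init__(self):
        self.temperature = None

    def state_dict(self) -> dict:
        # Plain Python values keep the dictionary easy to save and inspect.
        return {"temperature": float(self.temperature)}

    def load_state_dict(self, state: dict) -> None:
        if "temperature" not in state:
            raise SerializationError("missing key 'temperature'")
        try:
            temperature = float(state["temperature"])
        except (TypeError, ValueError) as exc:
            raise SerializationError("temperature must be a number") from exc
        # 'not > 0' also rejects NaN.
        if not temperature > 0:
            raise SerializationError(f"temperature must be positive, got {temperature}")
        self.temperature = temperature
```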

save(path)

Save calibrator state to a file.

Parameters:

path (str) – File path where the state will be saved.

Raises:

SerializationError – If saving fails.

Return type:

None

Example

>>> calibrator.fit(val_logits, val_labels)
>>> calibrator.save('calibrator.pt')

classmethod load(path)

Load a calibrator from a file.

Parameters:

path (str) – File path to load the state from.

Return type:

BaseCalibrator

Returns:

A new calibrator instance with loaded state.

Raises:

SerializationError – If loading fails.

Example

>>> calibrator = MyCalibrator.load('calibrator.pt')
>>> calibrated_dist = calibrator.predict(test_logits)
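This page does not show how save() and load() are implemented. Given the .pt extension in the examples, one plausible design, assumed here rather than confirmed, wraps state_dict() and load_state_dict() around torch.save and torch.load and converts failures into SerializationError. The class name is hypothetical and SerializationError is a local stand-in.

```python
import torch


class SerializationError(Exception):
    """Stand-in for incerto's SerializationError."""


class FileBackedCalibrator:
    """Hypothetical sketch of save()/load() built on state_dict()/load_state_dict()."""

    def __init__(self):
        self.temperature = 1.0

    def state_dict(self) -> dict:
        return {"temperature": self.temperature}

    def load_state_dict(self, state: dict) -> None:
        self.temperature = state["temperature"]

    def save(self, path: str) -> None:
        try:
            torch.save(self.state_dict(), path)
        except Exception as exc:  # e.g. the target directory does not exist
            raise SerializationError(f"could not save calibrator to {path!r}") from exc

    @classmethod
    def load(cls, path: str) -> "FileBackedCalibrator":
        try:
            state = torch.load(path)
        except Exception as exc:  # e.g. missing or corrupt file
            raise SerializationError(f"could not load calibrator from {path!r}") from exc
        # Assumes the class can be constructed without arguments.
        calibrator = cls()
        calibrator.load_state_dict(state)
        return calibrator
```

Because load() is a classmethod calling cls(), the same code returns an instance of whichever subclass it is invoked on, which matches the MyCalibrator.load('calibrator.pt') usage above.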