incerto.shift.ImportanceWeightingShift

class incerto.shift.ImportanceWeightingShift(method='logistic', alpha=0.01)

Bases: object

Importance weighting for covariate shift adaptation.

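Estimates the density ratio w(x) = p_target(x) / p_source(x) and uses it to re-weight source training samples, so that the weighted source loss approximates the expected loss on the target distribution.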
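Why re-weighting by w(x) works can be seen from the standard importance-sampling identity. Assuming p_source(x) > 0 wherever p_target(x) > 0, and that the conditional p(y | x) is shared by both domains (the covariate-shift assumption), the expected target loss equals a weighted expectation over the source distribution:

```latex
\mathbb{E}_{(x,y)\sim p_{\text{target}}}\big[\ell(x,y)\big]
  = \int \ell(x,y)\, p(y \mid x)\, \frac{p_{\text{target}}(x)}{p_{\text{source}}(x)}\, p_{\text{source}}(x)\, dx\, dy
  = \mathbb{E}_{(x,y)\sim p_{\text{source}}}\big[w(x)\, \ell(x,y)\big]
```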

Reference:

Sugiyama et al., “Direct Importance Estimation with Model Selection and Its Application to Covariate Shift Adaptation” (NIPS 2007)
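The cited paper introduces KLIEP, which models w(x) as a non-negative combination of Gaussian kernels centred on target points and maximizes the mean log-weight on target samples subject to the weights averaging to 1 over source samples. The sketch below follows that update scheme in NumPy; it assumes a fixed kernel width `sigma` (the paper selects it by likelihood cross-validation), and the helper names `gaussian_kernel` and `kliep_weights` are hypothetical, not part of `incerto`. Whether `method='kliep'` is implemented exactly this way is not stated here.

```python
import numpy as np


def gaussian_kernel(x, centers, sigma):
    # Gaussian kernel matrix between rows of x (n, d) and centers (b, d) -> shape (n, b)
    sq_dist = ((x[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-sq_dist / (2.0 * sigma ** 2))


def kliep_weights(source, target, sigma=1.0, n_centers=100, lr=1e-2, n_iter=1000, seed=0):
    """KLIEP sketch: w(x) = sum_l alpha_l * K(x, c_l), centers c_l drawn from target."""
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(target), size=min(n_centers, len(target)), replace=False)
    centers = target[idx]
    A = gaussian_kernel(target, centers, sigma)          # basis values at target points
    b = gaussian_kernel(source, centers, sigma).mean(0)  # mean basis value over source points
    alpha = np.full(len(centers), 1.0 / b.sum())         # feasible start: mean source weight = 1
    for _ in range(n_iter):
        # Gradient ascent on the mean log-weight of target samples
        alpha = alpha + lr * A.T @ (1.0 / (A @ alpha)) / len(target)
        # Re-impose the constraints: mean source weight = 1 and alpha >= 0
        alpha = alpha + (1.0 - b @ alpha) * b / (b @ b)
        alpha = np.maximum(alpha, 0.0)
        alpha = alpha / (b @ alpha)
    # Weights evaluated at the source samples
    return gaussian_kernel(source, centers, sigma) @ alpha
```

Because the last normalization divides by the mean basis value over source points, the returned source weights average exactly to 1, matching the KLIEP constraint.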

Parameters:
  • method (str) – Estimation method (‘kernel’, ‘logistic’, ‘kliep’)

  • alpha (float) – Regularization parameter


Methods

  • __init__([method, alpha]) – Initialize the estimator.

  • compute_weights(source_features) – Compute importance weights for source samples.

  • fit(source_features, target_features) – Estimate importance weights.

  • load(path[, method, alpha]) – Load importance weighting from file.

  • load_state_dict(state) – Load importance weighting state.

  • save(path) – Save importance weighting state.

  • state_dict() – Return the importance weighting state as a dict.

  • weighted_loss(loss, weights) – Apply importance weights to loss.

__init__(method='logistic', alpha=0.01)

Parameters:
  • method (str) – Estimation method (‘kernel’, ‘logistic’, ‘kliep’)

  • alpha (float) – Regularization parameter

fit(source_features, target_features)

Estimate importance weights.

Parameters:
  • source_features (Tensor) – Features from source domain (N_s, D)

  • target_features (Tensor) – Features from target domain (N_t, D)
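A common way to realize a 'logistic' estimator is to train a classifier that separates source samples (label 0) from target samples (label 1); by Bayes' rule, log[p_target(x) / p_source(x)] equals the classifier's logit plus log(N_s / N_t). The NumPy sketch below fits such a classifier by gradient descent on inputs shaped (N_s, D) and (N_t, D). It assumes, without confirmation from this page, that `method='logistic'` works this way and that `alpha` acts as an L2 penalty; `fit_logistic_ratio` is a hypothetical helper, not the library's code.

```python
import numpy as np


def fit_logistic_ratio(source, target, alpha=0.01, lr=0.1, n_iter=2000):
    """Fit a source-vs-target logistic classifier by gradient descent.

    Returns (theta, bias) such that log[p_target(x) / p_source(x)] is estimated by
    x @ theta + bias + log(N_s / N_t).
    """
    x = np.vstack([source, target])
    y = np.concatenate([np.zeros(len(source)), np.ones(len(target))])  # 1 = target domain
    theta = np.zeros(x.shape[1])
    bias = 0.0
    for _ in range(n_iter):
        p = 1.0 / (1.0 + np.exp(-(x @ theta + bias)))  # P(target | x)
        err = p - y
        # Mean log-loss gradient plus an L2 penalty (alpha as L2 strength is an assumption)
        theta -= lr * (x.T @ err / len(y) + alpha * theta)
        bias -= lr * err.mean()
    return theta, bias
```

Source weights then follow as `np.exp(source @ theta + bias) * len(source) / len(target)`.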

compute_weights(source_features)

Compute importance weights for source samples.

Parameters:

source_features (Tensor) – Source domain features

Return type:

Tensor

Returns:

Importance weights, one per source sample
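If the weights come from a probabilistic source-vs-target classifier, each weight is the classifier's odds rescaled by the sample-size ratio: w(x) = [P(target | x) / P(source | x)] · (N_s / N_t). The hypothetical helper below shows that conversion; the probability clipping with `eps` is an added safeguard for this sketch, not documented behaviour of `compute_weights`.

```python
import numpy as np


def classifier_to_weights(p_target, n_source, n_target, eps=1e-6):
    """Turn P(target | x) from a source-vs-target classifier into importance weights."""
    # Clip to avoid division by zero when the classifier is fully confident
    p = np.clip(np.asarray(p_target, dtype=float), eps, 1.0 - eps)
    # Bayes' rule: p_t(x)/p_s(x) = [P(t|x)/P(s|x)] * [P(s)/P(t)], with P(s)/P(t) = n_source/n_target
    return (p / (1.0 - p)) * (n_source / n_target)
```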

weighted_loss(loss, weights)

Apply importance weights to loss.

Parameters:
  • loss (Tensor) – Per-sample losses

  • weights (Tensor) – Importance weights

Return type:

Tensor

Returns:

Weighted average loss
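"Weighted average" can mean either the self-normalized average sum(w·ℓ) / sum(w) or the plain mean of w·ℓ; this page does not say which one `weighted_loss` returns. The hypothetical NumPy helper below computes both so the difference is visible; in PyTorch the same expressions are `(weights * loss).sum() / weights.sum()` and `(weights * loss).mean()`.

```python
import numpy as np


def weighted_average_loss(loss, weights, self_normalize=True):
    """Combine per-sample losses with importance weights (sketch, not the library code)."""
    loss = np.asarray(loss, dtype=float)
    weights = np.asarray(weights, dtype=float)
    if self_normalize:
        # Divide by the total weight: robust to weights that do not average to 1
        return float((weights * loss).sum() / weights.sum())
    # Unbiased importance-sampling estimate when the weights are exact density ratios
    return float((weights * loss).mean())
```

With loss = [1, 2, 3] and weights = [1, 1, 2], the self-normalized form gives 9 / 4 = 2.25 and the plain mean gives 9 / 3 = 3.0.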

state_dict()

Return the importance weighting state as a dict.

Return type:

dict

load_state_dict(state)

Load importance weighting state.

Parameters:

state (dict) – State dictionary, as returned by state_dict()

Return type:

None

save(path)

Save importance weighting state.

Parameters:

path (str) – Destination file path

Return type:

None

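classmethod load(path, method='logistic', alpha=0.01)

Load importance weighting from file.

Parameters:
  • path (str) – Path of a file written by save()

  • method (str) – Estimation method (‘kernel’, ‘logistic’, ‘kliep’)

  • alpha (float) – Regularization parameter

Return type:

ImportanceWeightingShift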
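The four persistence methods fit together as a round trip: state_dict() captures the state, save() writes it, and the load() classmethod builds an instance and restores it through load_state_dict(). The class below is a hypothetical stand-in (`ToyWeighting`) that mirrors only the documented method names; the pickle file format, the state keys, and the choice to let the saved state override the constructor arguments are all assumptions of this sketch, not confirmed behaviour of `incerto`.

```python
import pickle


class ToyWeighting:
    """Hypothetical stand-in mirroring the documented persistence methods."""

    def __init__(self, method="logistic", alpha=0.01):
        self.method = method
        self.alpha = alpha
        self.params = None  # fitted estimator parameters (assumed layout)

    def state_dict(self):
        # Everything needed to restore the estimator
        return {"method": self.method, "alpha": self.alpha, "params": self.params}

    def load_state_dict(self, state):
        self.method = state["method"]
        self.alpha = state["alpha"]
        self.params = state["params"]

    def save(self, path):
        with open(path, "wb") as f:
            pickle.dump(self.state_dict(), f)

    @classmethod
    def load(cls, path, method="logistic", alpha=0.01):
        # Construct with the given arguments, then let the saved state override them
        obj = cls(method=method, alpha=alpha)
        with open(path, "rb") as f:
            obj.load_state_dict(pickle.load(f))
        return obj
```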