incerto.shift.sliced_wasserstein_distance

incerto.shift.sliced_wasserstein_distance(x, y, num_projections=100, p=2.0, seed=None)

Sliced Wasserstein distance between two empirical distributions.

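Projects both samples onto random one-dimensional directions, computes the Wasserstein distance between the projected samples, and averages the result across projections. In one dimension the Wasserstein distance has a closed form obtained by sorting the projected samples, so each projection costs only a matrix-vector product and a sort. This makes the sliced distance much cheaper than solving the full optimal transport problem, and it remains practical for high-dimensional features.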
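For reference, the standard definition from the cited papers is written out below, with L = num_projections directions θ_l drawn from the uniform measure σ on the unit sphere S^{d-1}, and θ_#μ denoting the distribution of θᵀx for x ~ μ. The last expression is the 1D closed form for two empirical samples of equal size n, where a_(i) and b_(i) are the sorted projections θᵀx_i and θᵀy_i. This page does not say whether the function averages W_p or W_p^p across projections, so treat the exact averaging shown here as an assumption.

```latex
\mathrm{SW}_p(\mu, \nu)
  = \left( \int_{\mathbb{S}^{d-1}} W_p^p\!\left(\theta_{\#}\mu,\, \theta_{\#}\nu\right) \mathrm{d}\sigma(\theta) \right)^{1/p}
  \approx \left( \frac{1}{L} \sum_{l=1}^{L} W_p^p\!\left(\theta_{l\#}\mu,\, \theta_{l\#}\nu\right) \right)^{1/p},
\qquad
W_p^p\!\left(\theta_{\#}\hat\mu,\, \theta_{\#}\hat\nu\right) = \frac{1}{n} \sum_{i=1}^{n} \left| a_{(i)} - b_{(i)} \right|^p
```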

Parameters:
  • x (Tensor) – Source samples of shape (n, d)

  • y (Tensor) – Target samples of shape (m, d)

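  • num_projections (int) – Number of random projections (default: 100). More projections reduce the Monte Carlo variance of the estimate, at a cost that grows linearly with the number of projections.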

  • p (float) – Order of the Wasserstein distance (default: 2.0)

  • seed (Optional[int]) – Random seed used to draw the projection directions, for reproducibility (default: None)
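To show how num_projections, p and seed enter the computation, here is a minimal pure-Python sketch of the Monte Carlo estimator. It assumes equal sample sizes (n == m) and averages W_p^p before taking the p-th root. The names sliced_wasserstein_sketch and _random_direction are hypothetical and not part of incerto. The real function operates on Tensors, and how it handles n != m is not documented on this page.

```python
import math
import random
from typing import List, Optional, Sequence


def _random_direction(d: int, rng: random.Random) -> List[float]:
    """Hypothetical helper: uniform unit vector on the (d-1)-sphere via a normalized Gaussian."""
    while True:
        v = [rng.gauss(0.0, 1.0) for _ in range(d)]
        norm = math.sqrt(sum(c * c for c in v))
        if norm > 0.0:
            return [c / norm for c in v]


def sliced_wasserstein_sketch(
    x: Sequence[Sequence[float]],
    y: Sequence[Sequence[float]],
    num_projections: int = 100,
    p: float = 2.0,
    seed: Optional[int] = None,
) -> float:
    """Hypothetical sketch (not incerto's code): Monte Carlo sliced W_p for equal-size samples."""
    if len(x) != len(y):
        raise ValueError("this sketch assumes n == m")
    d = len(x[0])
    rng = random.Random(seed)  # the seed fixes the sequence of directions
    total = 0.0
    for _ in range(num_projections):
        theta = _random_direction(d, rng)
        # Project every sample onto the line spanned by theta, then sort.
        a = sorted(sum(t * c for t, c in zip(theta, row)) for row in x)
        b = sorted(sum(t * c for t, c in zip(theta, row)) for row in y)
        # In 1D, optimal transport pairs sorted samples: W_p^p = mean |a_(i) - b_(i)|^p.
        total += sum(abs(ai - bi) ** p for ai, bi in zip(a, b)) / len(a)
    # Average W_p^p over projections, then take the p-th root.
    return (total / num_projections) ** (1.0 / p)
```

As a sanity check, translating a sample by a vector s shifts every projection by θᵀs, and |θᵀs| ≤ ‖s‖ for a unit θ. The sliced distance is therefore at most ‖s‖, which is the full Wasserstein distance for a pure translation. In one dimension the two coincide.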

Return type:

float

Returns:

Sliced Wasserstein distance averaged over random projections

References:

  • Rabin et al., “Wasserstein Barycenter and Its Application to Texture Mixing” (SSVM 2011)

  • Kolouri et al., “Sliced-Wasserstein Autoencoder” (ICLR 2019)

Example

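>>> from incerto.shift import sliced_wasserstein_distance
>>> source_features = model(source_data)  # (n, d) features; model is any feature extractor
>>> target_features = model(target_data)  # (m, d) features with the same d
>>> distance = sliced_wasserstein_distance(source_features, target_features, seed=0)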