"""
incerto.shift_detection.metrics
===============================
Pure functions that measure how far two sample sets differ.
Designed so that you can mix and match with your own detectors.
"""
from __future__ import annotations
import torch
def energy_distance(x: torch.Tensor, y: torch.Tensor) -> float:
"""Szekely–Rizzo energy distance, O(n²) naive implementation."""
def pdist(t): # pairwise ℓ2
return torch.cdist(t, t, p=2).mean()
return (2 * torch.cdist(x, y, p=2).mean() - pdist(x) - pdist(y)).item()
def total_variation(p: torch.Tensor, q: torch.Tensor, eps: float = 1e-9) -> float:
"""Total variation between *discrete* distributions p and q."""
p = p / (p.sum() + eps)
q = q / (q.sum() + eps)
return 0.5 * torch.abs(p - q).sum().item()
def population_stability_index(p_hist, q_hist, eps: float = 1e-9) -> float:
"""Classic tabular PSI used in credit scoring."""
p, q = p_hist + eps, q_hist + eps
return ((q - p) * torch.log(q / p)).sum().item()
[docs]
def wasserstein_distance(
x: torch.Tensor, y: torch.Tensor, p: float = 2.0, max_iter: int = 100
) -> float:
"""
Wasserstein distance (Earth Mover's Distance) between two empirical distributions.
Computes the p-Wasserstein distance between two samples using Sinkhorn algorithm
for optimal transport. For p=1, this is the classic Earth Mover's Distance.
Args:
x: Source samples of shape (n, d)
y: Target samples of shape (m, d)
p: Order of the Wasserstein distance (default: 2.0 for W2 distance)
max_iter: Maximum iterations for Sinkhorn algorithm
Returns:
Wasserstein distance between the two distributions
Reference:
Cuturi, "Sinkhorn Distances: Lightspeed Computation of Optimal Transport"
(NeurIPS 2013)
Example:
>>> source_features = model(source_data)
>>> target_features = model(target_data)
>>> distance = wasserstein_distance(source_features, target_features)
"""
# For 1D case, use closed-form solution
if x.shape[1] == 1:
x_sorted = torch.sort(x.squeeze())[0]
y_sorted = torch.sort(y.squeeze())[0]
# Interpolate to same size
n, m = len(x_sorted), len(y_sorted)
if n != m:
max_size = max(n, m)
x_interp = torch.nn.functional.interpolate(
x_sorted.view(1, 1, -1),
size=max_size,
mode="linear",
align_corners=True,
).squeeze()
y_interp = torch.nn.functional.interpolate(
y_sorted.view(1, 1, -1),
size=max_size,
mode="linear",
align_corners=True,
).squeeze()
else:
x_interp, y_interp = x_sorted, y_sorted
return torch.mean(torch.abs(x_interp - y_interp) ** p).item() ** (1.0 / p)
# For multi-dimensional case, use Sinkhorn algorithm
n, m = x.shape[0], y.shape[0]
# Compute cost matrix (pairwise distances)
C = torch.cdist(x, y, p=p) ** p
# Uniform distribution over samples
a = torch.ones(n, device=x.device) / n
b = torch.ones(m, device=y.device) / m
# Sinkhorn iterations with entropy regularization
epsilon = 0.1 # regularization parameter
K = torch.exp(-C / epsilon)
u = torch.ones(n, device=x.device) / n
v = torch.ones(m, device=y.device) / m
for _ in range(max_iter):
u = a / (K @ v + 1e-10)
v = b / (K.T @ u + 1e-10)
# Compute optimal transport cost
transport_plan = u.unsqueeze(1) * K * v.unsqueeze(0)
cost = (transport_plan * C).sum().item()
return cost ** (1.0 / p)
[docs]
def sliced_wasserstein_distance(
x: torch.Tensor,
y: torch.Tensor,
num_projections: int = 100,
p: float = 2.0,
seed: int | None = None,
) -> float:
"""
Sliced Wasserstein distance between two empirical distributions.
Projects the distributions onto random 1D lines and computes the average
Wasserstein distance across projections. This is much faster than the full
Wasserstein distance and scales to high dimensions.
Args:
x: Source samples of shape (n, d)
y: Target samples of shape (m, d)
num_projections: Number of random projections (default: 100)
p: Order of the Wasserstein distance (default: 2.0)
seed: Random seed for reproducibility
Returns:
Sliced Wasserstein distance averaged over random projections
Reference:
Rabin et al., "Wasserstein Barycenter and Its Application to Texture Mixing"
(SSVM 2011)
Kolouri et al., "Sliced-Wasserstein Autoencoder" (ICLR 2019)
Example:
>>> source_features = model(source_data)
>>> target_features = model(target_data)
>>> distance = sliced_wasserstein_distance(source_features, target_features)
"""
if seed is not None:
torch.manual_seed(seed)
d = x.shape[1]
distances = []
for _ in range(num_projections):
# Random projection direction
theta = torch.randn(d, device=x.device)
theta = theta / torch.norm(theta)
# Project samples onto theta
x_proj = x @ theta
y_proj = y @ theta
# Sort projected samples
x_sorted = torch.sort(x_proj)[0]
y_sorted = torch.sort(y_proj)[0]
# Compute 1D Wasserstein distance
# If different sizes, interpolate to match
n, m = len(x_sorted), len(y_sorted)
if n != m:
max_size = max(n, m)
x_interp = torch.nn.functional.interpolate(
x_sorted.view(1, 1, -1),
size=max_size,
mode="linear",
align_corners=True,
).squeeze()
y_interp = torch.nn.functional.interpolate(
y_sorted.view(1, 1, -1),
size=max_size,
mode="linear",
align_corners=True,
).squeeze()
else:
x_interp, y_interp = x_sorted, y_sorted
# p-Wasserstein distance in 1D
dist = torch.mean(torch.abs(x_interp - y_interp) ** p) ** (1.0 / p)
distances.append(dist)
return torch.stack(distances).mean().item()