"""
JAX-based distance functions for Prosemble.
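This module provides GPU-accelerated, vectorized distance computations
using JAX. Most functions are JIT-compiled for maximum performance; a few
(the L-p norm helpers, ``lomega_distance`` and ``estimate_sigma``) are plain
Python functions built from JAX operations.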
Mathematical Background
-----------------------
Distance metrics are fundamental to prototype-based learning algorithms.
This implementation focuses on:
1. Batch/matrix operations (no Python loops)
2. GPU compatibility
3. JIT compilation for speed
4. Numerical stability
Author: Prosemble Contributors
License: MIT
"""
import jax
import jax.numpy as jnp
from jax import jit
from functools import partial
import chex
# ============================================================================
# Core Distance Functions (Pairwise Matrices)
# ============================================================================
[docs]
@jit
def euclidean_distance_matrix(X: chex.Array, Y: chex.Array) -> chex.Array:
    r"""
    Compute pairwise Euclidean distances between rows of X and Y.

    Uses the expansion trick :math:`\|x - y\|^2 = \|x\|^2 + \|y\|^2 - 2 x^T y`,
    so the pairwise term is a single matrix product.

    Formula: :math:`D_{ij} = \|X_i - Y_j\| = \sqrt{\sum_k (X_{ik} - Y_{jk})^2}`

    Args:
        X: Array of shape (n, d) - n samples with d features
        Y: Array of shape (m, d) - m samples with d features

    Returns:
        D: Array of shape (n, m) where :math:`D_{ij} = \|X_i - Y_j\|`

    Complexity:
        Time: O(nmd) - single matrix multiplication
        Space: O(nm) - output matrix

    Example:
        >>> X = jnp.array([[0, 0], [1, 1], [2, 2]])
        >>> Y = jnp.array([[0, 0], [3, 3]])
        >>> D = euclidean_distance_matrix(X, Y)
        >>> D.shape
        (3, 2)
        >>> float(D[0, 0])  # Distance from X[0] to Y[0]
        0.0
        >>> float(D[2, 1])  # Distance from X[2] to Y[1], i.e. sqrt(2)
        1.414...

    Notes:
        - The squared distances are clamped at zero before the square root,
          because round-off in the expansion can make them slightly negative.
        - Entries whose squared distance is exactly zero bypass ``sqrt``. The
          forward value is still 0, but differentiating through this function
          gives a zero gradient at those entries instead of NaN.
        - JIT-compiled: the first call compiles, later calls reuse the
          compiled computation.
    """
    chex.assert_rank(X, 2)
    chex.assert_rank(Y, 2)
    chex.assert_equal(X.shape[1], Y.shape[1])

    # Squared row norms, shaped so they broadcast against the (n, m) product.
    x_sq = jnp.sum(X ** 2, axis=1, keepdims=True)      # (n, 1)
    y_sq = jnp.sum(Y ** 2, axis=1, keepdims=True).T    # (1, m)
    cross = X @ Y.T                                    # (n, m)

    # Round-off in the expansion can push tiny distances below zero.
    D_sq = jnp.maximum(x_sq + y_sq - 2 * cross, 0.0)

    # d/dx sqrt(x) is unbounded at x == 0, which turns into NaN once it is
    # multiplied by the zero upstream gradient. Feed sqrt a harmless 1.0 at
    # those entries and select the true value (0) afterwards. NaN/inf inputs
    # are not masked, so they still propagate exactly as before.
    is_zero = D_sq == 0
    safe_sq = jnp.where(is_zero, 1.0, D_sq)
    return jnp.where(is_zero, 0.0, jnp.sqrt(safe_sq))
[docs]
@jit
def squared_euclidean_distance_matrix(X: chex.Array, Y: chex.Array) -> chex.Array:
    r"""
    Compute pairwise squared Euclidean distances between rows of X and Y.

    Evaluates :math:`\|x\|^2 + \|y\|^2 - 2 x^T y` for every pair of rows and
    never takes a square root.

    Formula: :math:`D^2_{ij} = \|X_i - Y_j\|^2 = \sum_k (X_{ik} - Y_{jk})^2`

    Args:
        X: Array of shape (n, d)
        Y: Array of shape (m, d)

    Returns:
        D: Array of shape (n, m) where :math:`D^2_{ij} = \|X_i - Y_j\|^2`

    Complexity:
        Time: O(nmd)
        Space: O(nm)

    Example:
        >>> X = jnp.array([[0, 0], [1, 1]])
        >>> Y = jnp.array([[0, 0], [2, 2]])
        >>> D_sq = squared_euclidean_distance_matrix(X, Y)
        >>> float(D_sq[1, 1])  # Squared distance from [1,1] to [2,2]
        2.0

    Notes:
        - Prefer this over the Euclidean variant whenever squared distances
          are all the caller needs.
        - The result is clamped at zero, so round-off in the expansion never
          surfaces as a negative squared distance.
    """
    chex.assert_rank(X, 2)
    chex.assert_rank(Y, 2)
    chex.assert_equal(X.shape[1], Y.shape[1])

    norms_x = jnp.sum(X ** 2, axis=1, keepdims=True)    # (n, 1)
    norms_y = jnp.sum(Y ** 2, axis=1, keepdims=True).T  # (1, m)
    gram = X @ Y.T                                      # (n, m)

    # Broadcasting (n, 1) + (1, m) - (n, m) yields the full pairwise table.
    return jnp.maximum(norms_x + norms_y - 2 * gram, 0.0)
[docs]
@jit
def manhattan_distance_matrix(X: chex.Array, Y: chex.Array) -> chex.Array:
    r"""
    Compute pairwise Manhattan (L1) distances between rows of X and Y.

    Formula: :math:`D_{ij} = \|X_i - Y_j\|_1 = \sum_k |X_{ik} - Y_{jk}|`

    Args:
        X: Array of shape (n, d)
        Y: Array of shape (m, d)

    Returns:
        D: Array of shape (n, m) where D[i, j] is the Manhattan distance
        between X[i] and Y[j]

    Complexity:
        Time: O(nmd)
        Space: O(nmd) - the broadcast difference tensor

    Example:
        >>> X = jnp.array([[0, 0], [1, 1]])
        >>> Y = jnp.array([[0, 0], [2, 2]])
        >>> D = manhattan_distance_matrix(X, Y)
        >>> int(D[1, 1])  # Manhattan distance from [1,1] to [2,2]
        2

    Implementation:
        ``X[:, None, :] - Y[None, :, :]`` broadcasts to shape (n, m, d); the
        absolute values are then summed over the feature axis.

    Notes:
        - Also known as "taxicab" or "city block" distance
    """
    chex.assert_rank(X, 2)
    chex.assert_rank(Y, 2)
    chex.assert_equal(X.shape[1], Y.shape[1])

    # (n, 1, d) - (1, m, d) -> (n, m, d)
    abs_gap = jnp.abs(X[:, None, :] - Y[None, :, :])
    return jnp.sum(abs_gap, axis=2)
[docs]
def lpnorm_distance_matrix(
    X: chex.Array,
    Y: chex.Array,
    p: float | int
) -> chex.Array:
    r"""
    Compute pairwise L-p norm distances between rows of X and Y.

    Formula: :math:`D_{ij} = \|X_i - Y_j\|_p = (\sum_k |X_{ik} - Y_{jk}|^p)^{1/p}`

    Special Cases:
        p = 1: Manhattan distance
        p = 2: Euclidean distance
        p = :math:`\infty`: Chebyshev distance (max absolute difference)

    Args:
        X: Array of shape (n, d)
        Y: Array of shape (m, d)
        p: Order of the norm (p >= 1). It is compared with ``jnp.inf`` in
            ordinary Python control flow, which is why this function is not
            wrapped in ``jit``.

    Returns:
        D: Array of shape (n, m) where D[i, j] is the L-p distance

    Complexity:
        Time: O(nmd)
        Space: O(nmd)

    Example:
        >>> X = jnp.array([[0, 0], [1, 1]])
        >>> Y = jnp.array([[0, 0], [3, 4]])
        >>> D = lpnorm_distance_matrix(X, Y, p=2)  # Euclidean
        >>> D = lpnorm_distance_matrix(X, Y, p=1)  # Manhattan
        >>> D = lpnorm_distance_matrix(X, Y, p=jnp.inf)  # Chebyshev

    Notes:
        - For p=1, manhattan_distance_matrix computes the same quantity
        - For p=2, euclidean_distance_matrix avoids the (n, m, d) tensor
        - For p=inf, computes ``max(|x - y|)``
    """
    chex.assert_rank(X, 2)
    chex.assert_rank(Y, 2)
    chex.assert_equal(X.shape[1], Y.shape[1])

    # (n, 1, d) - (1, m, d) -> (n, m, d)
    gap = X[:, None, :] - Y[None, :, :]

    # Chebyshev limit: the largest coordinate-wise gap.
    if p == jnp.inf:
        return jnp.max(jnp.abs(gap), axis=2)

    # General case: p-th root of the summed p-th powers.
    powered_sum = jnp.sum(jnp.power(jnp.abs(gap), p), axis=2)
    return jnp.power(powered_sum, 1.0 / p)
[docs]
@jit
def omega_distance_matrix(
    X: chex.Array,
    Y: chex.Array,
    omega: chex.Array
) -> chex.Array:
    r"""
    Compute squared distances after projecting both inputs through Omega.

    Formula: :math:`D_{ij} = \|X_i \Omega - Y_j \Omega\|^2`

    where :math:`\Omega` is a projection matrix applied to the feature space.

    Args:
        X: Array of shape (n, d)
        Y: Array of shape (m, d)
        omega: Projection matrix of shape (d, k) where k is the projection
            dimension

    Returns:
        D: Array of shape (n, m) with squared distances in projected space

    Complexity:
        Time: O(ndk + mdk + nmk) = O((n+m)dk + nmk)
        Space: O(nk + mk + nm)

    Use Cases:
        - Dimensionality reduction for distance computation
        - Learning relevance of features (omega learned from data)
        - Mahalanobis-like distances (when :math:`\Omega = L` where :math:`\Sigma = LL^T`)

    Example:
        >>> X = jnp.array([[1, 2, 3], [4, 5, 6]])
        >>> Y = jnp.array([[0, 0, 0], [1, 1, 1]])
        >>> omega = jnp.array([[1, 0], [0, 1], [0, 0]])  # Keep first 2 dims
        >>> D = omega_distance_matrix(X, Y, omega)
        >>> D.shape
        (2, 2)

    Notes:
        - When omega is the identity, this reduces to the squared Euclidean
          distance
        - The distances are produced by squared_euclidean_distance_matrix on
          the projected rows
    """
    chex.assert_rank(X, 2)
    chex.assert_rank(Y, 2)
    chex.assert_rank(omega, 2)
    chex.assert_equal(X.shape[1], omega.shape[0])
    chex.assert_equal(Y.shape[1], omega.shape[0])

    # Map both point sets into the k-dimensional space first ...
    projected_X = X @ omega  # (n, k)
    projected_Y = Y @ omega  # (m, k)

    # ... then measure squared distances there.
    return squared_euclidean_distance_matrix(projected_X, projected_Y)
[docs]
@jit
def lomega_distance_matrix(
    X: chex.Array,
    Y: chex.Array,
    omegas: chex.Array
) -> chex.Array:
    r"""
    Compute distances using one projection matrix per prototype (Local Omega).

    Formula: :math:`D_{ij} = \|X_i \Omega_j - Y_j \Omega_j\|^2`

    Column j of the result is measured entirely in the space defined by
    :math:`\Omega_j`, the projection matrix belonging to prototype Y[j].

    Args:
        X: Array of shape (n, d) - data points
        Y: Array of shape (m, d) - prototypes/centroids
        omegas: Array of shape (m, d, k) - m projection matrices of size (d, k).
            Each Y[j] has its own projection matrix omegas[j].

    Returns:
        D: Array of shape (n, m) with squared distances in the per-prototype
        projected spaces

    Complexity:
        Time: O(nmdk)
        Space: O(nmk) - every point projected through every matrix

    Use Cases:
        - Local relevance learning (each prototype has its own metric)
        - Cluster-specific feature weighting

    Example:
        >>> n, m, d, k = 10, 3, 5, 2
        >>> X = jax.random.normal(jax.random.PRNGKey(0), (n, d))
        >>> Y = jax.random.normal(jax.random.PRNGKey(1), (m, d))
        >>> omegas = jax.random.normal(jax.random.PRNGKey(2), (m, d, k))
        >>> D = lomega_distance_matrix(X, Y, omegas)
        >>> D.shape
        (10, 3)

    Implementation:
        Two einsum contractions:
        1. ``X[i] @ omegas[j]`` for all (i, j) pairs
        2. ``Y[j] @ omegas[j]`` for each j (prototype paired with its own matrix)
        The squared differences are then summed over the projection axis.

    Notes:
        - Generalizes omega_distance_matrix from one global matrix to one
          matrix per prototype
    """
    chex.assert_rank(X, 2)
    chex.assert_rank(Y, 2)
    chex.assert_rank(omegas, 3)
    chex.assert_equal(X.shape[1], omegas.shape[1])
    chex.assert_equal(Y.shape[1], omegas.shape[1])
    chex.assert_equal(Y.shape[0], omegas.shape[0])

    # Every point through every prototype's matrix: (n, d) x (m, d, k) -> (n, m, k)
    X_proj = jnp.einsum('nd,mdk->nmk', X, omegas)

    # Each prototype through its own matrix only: (m, d) x (m, d, k) -> (m, k)
    Y_proj = jnp.einsum('md,mdk->mk', Y, omegas)

    # (n, m, k) - (1, m, k), then collapse the projection axis -> (n, m)
    return jnp.sum((X_proj - Y_proj[None, :, :]) ** 2, axis=2)
[docs]
@jit
def tangent_distance_matrix(
    X: chex.Array,
    Y: chex.Array,
    omegas: chex.Array
) -> chex.Array:
    r"""
    Compute pairwise localized tangent distances.

    Each prototype j carries a subspace basis :math:`\Omega_j` of shape (d, s).
    The component of the difference vector lying in that subspace is removed
    before the squared norm is taken:

    .. math::

        d(x, w_j) = \|(I - \Omega_j \Omega_j^T)(x - w_j)\|^2

    Parameters
    ----------
    X : array of shape (n, d)
        Data points.
    Y : array of shape (m, d)
        Prototypes.
    omegas : array of shape (m, d, s)
        Orthogonal subspace bases per prototype, where s is the
        subspace dimension.

    Returns
    -------
    D : array of shape (n, m)
        Squared tangent distances.

    Notes
    -----
    Based on Saralajew, S., & Villmann, T. (2016). Adaptive tangent
    distances in generalized learning vector quantization.
    """
    chex.assert_rank(X, 2)
    chex.assert_rank(Y, 2)
    chex.assert_rank(omegas, 3)
    chex.assert_equal(X.shape[1], Y.shape[1])
    chex.assert_equal(Y.shape[0], omegas.shape[0])
    chex.assert_equal(Y.shape[1], omegas.shape[1])

    # Difference of every point to every prototype: (n, m, d)
    delta = X[:, None, :] - Y[None, :, :]

    # Coordinates of each difference in its prototype's subspace: (n, m, s)
    coords = jnp.einsum('nmd,mds->nms', delta, omegas)

    # Map those coordinates back into feature space: (n, m, d).
    # Contracting over s with the same (m, d, s) tensor applies omega^T.
    in_subspace = jnp.einsum('nms,mds->nmd', coords, omegas)

    # What remains after removing the subspace component.
    residual = delta - in_subspace
    return jnp.sum(residual ** 2, axis=2)
# ============================================================================
# Wasserstein Distance Functions
# ============================================================================
[docs]
@jit
def wasserstein2_distance_matrix(
    X: chex.Array,
    means: chex.Array,
    log_variances: chex.Array
) -> chex.Array:
    r"""
    Compute pairwise squared 2-Wasserstein distances from points to Gaussian prototypes.

    Each prototype is a diagonal Gaussian :math:`\mathcal{N}(\mu_k, \text{diag}(\sigma_k^2))`
    and each input point :math:`x` is treated as a Dirac delta :math:`\delta_x`.
    For that pairing the squared 2-Wasserstein distance is:

    .. math::

        W_2^2(\delta_x, \mathcal{N}(\mu_k, \text{diag}(\sigma_k^2)))
        = \sum_j (x_j - \mu_{kj})^2 + \sum_j \sigma_{kj}^2

    i.e. the squared Euclidean distance from the point to the mean plus the
    total variance (trace of the covariance) of the prototype.

    Parameters
    ----------
    X : array of shape (n, d)
        Data points.
    means : array of shape (p, d)
        Prototype mean vectors.
    log_variances : array of shape (p, d)
        Log of prototype variances (ensures positivity via ``exp``).

    Returns
    -------
    D : array of shape (n, p)
        Squared 2-Wasserstein distances.

    References
    ----------
    .. [1] Villani, C. (2009). Optimal Transport: Old and New.
       Springer. Chapter 2.
    .. [2] Givens, C. R. & Shortt, R. M. (1984). A class of Wasserstein
       metrics for probability distributions. Michigan Math. J., 31(2).
    """
    chex.assert_rank(X, 2)
    chex.assert_rank(means, 2)
    chex.assert_rank(log_variances, 2)
    chex.assert_equal(X.shape[1], means.shape[1])
    chex.assert_equal(means.shape, log_variances.shape)

    # Point-to-mean term: (n, p)
    point_to_mean = squared_euclidean_distance_matrix(X, means)

    # Per-prototype total variance, recovered from the log parametrisation: (p,)
    total_variance = jnp.sum(jnp.exp(log_variances), axis=1)

    # The variance term depends only on the prototype, so it broadcasts
    # across all n rows.
    return point_to_mean + total_variance[None, :]
[docs]
@jit
def wasserstein2_omega_distance_matrix(
    X: chex.Array,
    means: chex.Array,
    log_variances: chex.Array,
    omega: chex.Array
) -> chex.Array:
    r"""
    Squared 2-Wasserstein distance with global metric adaptation.

    Data and means are projected through :math:`\Omega` before the Euclidean
    component is computed; the variances are added unprojected:

    .. math::

        W_2^2(x, k) = \|\Omega(x - \mu_k)\|^2 + \sum_j \sigma_{kj}^2

    Parameters
    ----------
    X : array of shape (n, d)
        Data points.
    means : array of shape (p, d)
        Prototype mean vectors.
    log_variances : array of shape (p, d)
        Log of prototype variances.
    omega : array of shape (d, l)
        Global projection matrix.

    Returns
    -------
    D : array of shape (n, p)
        Squared 2-Wasserstein distances in projected space.
    """
    chex.assert_rank(X, 2)
    chex.assert_rank(means, 2)
    chex.assert_rank(log_variances, 2)
    chex.assert_rank(omega, 2)
    chex.assert_equal(X.shape[1], means.shape[1])
    chex.assert_equal(means.shape, log_variances.shape)
    chex.assert_equal(X.shape[1], omega.shape[0])

    # Euclidean component, measured after projection into l dimensions.
    projected_points = X @ omega      # (n, l)
    projected_means = means @ omega   # (p, l)
    point_to_mean = squared_euclidean_distance_matrix(
        projected_points, projected_means
    )  # (n, p)

    # Variance component: summed over the original d features, per prototype.
    total_variance = jnp.sum(jnp.exp(log_variances), axis=1)  # (p,)

    return point_to_mean + total_variance[None, :]
[docs]
@jit
def wasserstein2_relevance_distance_matrix(
    X: chex.Array,
    means: chex.Array,
    log_variances: chex.Array,
    relevances: chex.Array
) -> chex.Array:
    r"""
    Squared 2-Wasserstein distance with feature relevance weighting.

    Per-feature relevance weights :math:`\lambda_j` scale the Euclidean
    component only:

    .. math::

        W_2^2(x, k) = \sum_j \lambda_j (x_j - \mu_{kj})^2 + \sum_j \sigma_{kj}^2

    where :math:`\lambda_j = \text{softmax}(r)_j`, so the weights are
    non-negative and sum to 1.

    Parameters
    ----------
    X : array of shape (n, d)
        Data points.
    means : array of shape (p, d)
        Prototype mean vectors.
    log_variances : array of shape (p, d)
        Log of prototype variances.
    relevances : array of shape (d,)
        Raw relevance logits (softmax applied internally).

    Returns
    -------
    D : array of shape (n, p)
        Relevance-weighted squared 2-Wasserstein distances.
    """
    chex.assert_rank(X, 2)
    chex.assert_rank(means, 2)
    chex.assert_rank(log_variances, 2)
    chex.assert_rank(relevances, 1)
    chex.assert_equal(X.shape[1], means.shape[1])
    chex.assert_equal(means.shape, log_variances.shape)
    chex.assert_equal(X.shape[1], relevances.shape[0])

    # Normalise the logits into per-feature weights: (d,)
    weights = jax.nn.softmax(relevances)

    # Per-feature squared gaps, (n, 1, d) - (1, p, d) -> (n, p, d)
    sq_gap = (X[:, None, :] - means[None, :, :]) ** 2

    # Weight each feature, then collapse the feature axis: (n, p)
    weighted_sq_dist = jnp.sum(sq_gap * weights[None, None, :], axis=2)

    # Unweighted total variance per prototype: (p,)
    total_variance = jnp.sum(jnp.exp(log_variances), axis=1)

    return weighted_sq_dist + total_variance[None, :]
# ============================================================================
# Kernel Functions
# ============================================================================
[docs]
@jit
def gaussian_kernel_matrix(
    X: chex.Array,
    Y: chex.Array,
    sigma: float
) -> chex.Array:
    r"""
    Compute Gaussian (RBF) kernel matrix.

    Formula: :math:`K_{ij} = \exp(-\|X_i - Y_j\|^2 / (2\sigma^2))`

    Args:
        X: Array of shape (n, d)
        Y: Array of shape (m, d)
        sigma: Bandwidth parameter (:math:`\sigma > 0`). It is not validated
            here; ``sigma == 0`` makes the exponent a division by zero.

    Returns:
        K: Array of shape (n, m) where :math:`K_{ij} \in [0, 1]`.
        :math:`K_{ij} = 1` when :math:`X_i = Y_j`;
        :math:`K_{ij} \to 0` as :math:`\|X_i - Y_j\| \to \infty`

    Complexity:
        Time: O(nmd)
        Space: O(nm)

    Properties:
        - K is symmetric if X = Y
        - K[i, i] = 1 if X = Y (self-similarity)
        - The Gaussian kernel is positive semi-definite (a valid kernel)

    Example:
        >>> X = jnp.array([[0, 0], [1, 1]])
        >>> Y = jnp.array([[0, 0], [2, 2]])
        >>> K = gaussian_kernel_matrix(X, Y, sigma=1.0)
        >>> float(K[0, 0])  # X[0] equals Y[0]
        1.0
        >>> bool(K[0, 1] < K[0, 0])  # Decreases with distance
        True

    Kernel Trick:
        For a feature map :math:`\phi` with
        :math:`K(x, y) = \langle\phi(x), \phi(y)\rangle`, the distance in
        feature space is
        :math:`\|\phi(x) - \phi(y)\|^2 = K(x,x) - 2K(x,y) + K(y,y) = 2 - 2K(x,y)`
        because :math:`K(x, x) = 1` for this kernel.

    Use Cases:
        - Kernel Fuzzy C-Means (KFCM)
        - Kernel Possibilistic C-Means (KPCM)
        - Support Vector Machines (SVM)
        - Gaussian Processes

    Hyperparameter Tuning:
        - Small :math:`\sigma`: the kernel decays quickly with distance
        - Large :math:`\sigma`: the kernel decays slowly with distance
        - Rule of thumb: :math:`\sigma \approx \text{median}(\text{pairwise\_distances}) / \sqrt{2 \cdot n_\text{clusters}}`

    Notes:
        - sigma is the bandwidth, NOT the variance (variance = :math:`\sigma^2`)
        - The squared distances come from squared_euclidean_distance_matrix,
          which clamps them at zero, so the exponent is never positive
        - This docstring is a raw string: in a normal string ``\t``, ``\r``
          and ``\a`` inside the LaTeX would become control characters
    """
    chex.assert_rank(X, 2)
    chex.assert_rank(Y, 2)
    chex.assert_equal(X.shape[1], Y.shape[1])

    # Pairwise squared distances: (n, m)
    sq_dists = squared_euclidean_distance_matrix(X, Y)

    # exp(-d^2 / (2 sigma^2)), applied element-wise
    return jnp.exp(-sq_dists / (2 * sigma ** 2))
[docs]
@jit
def polynomial_kernel_matrix(
    X: chex.Array,
    Y: chex.Array,
    degree: int = 3,
    coef0: float = 1.0
) -> chex.Array:
    r"""
    Compute polynomial kernel matrix.

    Formula: :math:`K_{ij} = (X_i^T Y_j + c)^d`

    where *d* is ``degree`` and *c* is ``coef0``.

    Args:
        X: Array of shape (n, d)
        Y: Array of shape (m, d)
        degree: Polynomial degree (:math:`d \ge 1`)
        coef0: Coefficient (:math:`c \ge 0`)

    Returns:
        K: Array of shape (n, m) with polynomial kernel values

    Example:
        >>> X = jnp.array([[1, 2], [3, 4]])
        >>> Y = jnp.array([[1, 0], [0, 1]])
        >>> K = polynomial_kernel_matrix(X, Y, degree=2, coef0=1.0)

    Notes:
        - degree=1, coef0=0: Linear kernel (plain dot product)
        - coef0 shifts every dot product before the power is applied
    """
    chex.assert_rank(X, 2)
    chex.assert_rank(Y, 2)
    chex.assert_equal(X.shape[1], Y.shape[1])

    # Row-by-row inner products, offset by the constant term: (n, m)
    shifted = X @ Y.T + coef0

    return jnp.power(shifted, degree)
# ============================================================================
# Pairwise Distance Functions (for single pairs)
# ============================================================================
[docs]
@jit
def euclidean_distance(x: chex.Array, y: chex.Array) -> chex.Array:
    """
    Euclidean distance between two vectors.

    Args:
        x: Array of shape (d,)
        y: Array of shape (d,)

    Returns:
        Scalar distance

    Example:
        >>> x = jnp.array([0, 0, 0])
        >>> y = jnp.array([1, 1, 1])
        >>> d = euclidean_distance(x, y)
        >>> d
        Array(1.732..., dtype=float32)

    Notes:
        When ``x`` equals ``y`` the square root is bypassed: the result is
        still 0, but differentiating through this function gives a zero
        gradient there instead of NaN.
    """
    chex.assert_equal_shape([x, y])

    sq_dist = jnp.sum((x - y) ** 2)

    # d/ds sqrt(s) is unbounded at s == 0, which becomes NaN once multiplied
    # by the zero upstream gradient. Feed sqrt a harmless 1.0 in that case and
    # select the true value (0) afterwards. NaN inputs are not masked, so they
    # still propagate as before.
    is_zero = sq_dist == 0
    safe_sq = jnp.where(is_zero, 1.0, sq_dist)
    return jnp.where(is_zero, 0.0, jnp.sqrt(safe_sq))
[docs]
@jit
def squared_euclidean_distance(x: chex.Array, y: chex.Array) -> chex.Array:
    """
    Squared Euclidean distance between two vectors.

    Args:
        x: Array of shape (d,)
        y: Array of shape (d,); must have the same shape as ``x``

    Returns:
        Scalar squared distance
    """
    chex.assert_equal_shape([x, y])

    # Element-wise gap, squared and summed over every entry.
    delta = x - y
    return jnp.sum(delta ** 2)
[docs]
@jit
def manhattan_distance(x: chex.Array, y: chex.Array) -> chex.Array:
    """
    Manhattan (L1) distance between two vectors.

    Args:
        x: Array of shape (d,)
        y: Array of shape (d,); must have the same shape as ``x``

    Returns:
        Scalar distance
    """
    chex.assert_equal_shape([x, y])

    # Sum of absolute coordinate-wise gaps.
    abs_gap = jnp.abs(x - y)
    return jnp.sum(abs_gap)
[docs]
def lpnorm_distance(x: chex.Array, y: chex.Array, p: float = 2) -> chex.Array:
    """
    Lp-norm distance between two vectors.

    Args:
        x: Array of shape (d,)
        y: Array of shape (d,); must have the same shape as ``x``
        p: Order of the norm (supports inf); forwarded as ``ord`` to
            ``jnp.linalg.norm``

    Returns:
        Scalar distance
    """
    chex.assert_equal_shape([x, y])

    # The norm of the difference vector is the distance.
    delta = x - y
    return jnp.linalg.norm(delta, ord=p)
[docs]
@jit
def omega_distance(x: chex.Array, y: chex.Array, omega: chex.Array) -> chex.Array:
    r"""
    Omega (projection-based) distance between two vectors.

    Computes :math:`\|\text{diff} \cdot \Omega\|^2` where :math:`\text{diff} = x - y`.

    Args:
        x: Array of shape (d,)
        y: Array of shape (d,); must have the same shape as ``x``
        omega: Projection matrix of shape (d, k)

    Returns:
        Scalar squared distance in projected space

    Notes:
        This docstring is a raw string: in a normal string the ``\t`` in
        ``\text`` would be read as a tab character.
    """
    chex.assert_equal_shape([x, y])

    # Project the difference vector, then take its squared norm.
    projected = (x - y) @ omega
    return jnp.sum(projected ** 2)
[docs]
def lomega_distance(X: chex.Array, Y: chex.Array, omegas: chex.Array) -> chex.Array:
    """
    Local omega distance with per-prototype projection matrices.

    Entry (i, j) of the result is the squared norm of ``(X[i] - Y[j])``
    after projection through ``omegas[j]``.

    Args:
        X: Array of shape (n, d)
        Y: Array of shape (m, d) — prototypes
        omegas: Array of shape (m, d, k) — one projection matrix per prototype

    Returns:
        Distance matrix of shape (n, m)
    """
    def _projected_sq_norm(point, prototype, projection):
        # Squared length of the projected difference for one (point, prototype) pair.
        return jnp.sum(((point - prototype) @ projection) ** 2)

    # Inner map: hold the point fixed, pair each prototype with its own matrix.
    against_all_prototypes = jax.vmap(_projected_sq_norm, in_axes=(None, 0, 0))

    # Outer map: repeat that for every row of X.
    for_all_points = jax.vmap(against_all_prototypes, in_axes=(0, None, None))

    return for_all_points(X, Y, omegas)
# ============================================================================
# Utility Functions
# ============================================================================
[docs]
def estimate_sigma(
    X: chex.Array,
    percentile: float = 50.0,
    max_samples: int = 1000,
    seed: int = 0,
) -> float:
    r"""
    Estimate sigma for Gaussian kernel using pairwise distances.

    Strategy: return the requested percentile (the median by default) of the
    pairwise Euclidean distances between distinct rows of X.

    Args:
        X: Data array of shape (n, d)
        percentile: Percentile of distances to use (0-100)
        max_samples: Largest number of rows used for the pairwise distance
            matrix. When ``n > max_samples``, that many rows are drawn
            without replacement. Defaults to 1000.
        seed: Seed for the PRNG key that selects the subsample. Defaults to 0,
            so repeated calls on the same data give the same estimate.

    Returns:
        sigma: Estimated bandwidth parameter, as a Python float

    Example:
        >>> X = jax.random.normal(jax.random.PRNGKey(0), (100, 10))
        >>> sigma = estimate_sigma(X, percentile=50)

    Notes:
        - The value returned is the raw percentile. The common heuristic
          :math:`\sigma = \text{median\_distance} / \sqrt{2 \cdot n_\text{clusters}}`
          is NOT applied here; callers who want it must divide themselves.
        - Subsampling bounds the O(n²) pairwise computation.
        - With fewer than two rows there are no pairs, so the percentile is
          taken over an empty array (NOTE(review): confirm what
          ``jnp.percentile`` yields in that case before relying on it).
    """
    # Cap the number of rows so the pairwise matrix stays manageable.
    n = X.shape[0]
    if n > max_samples:
        key = jax.random.PRNGKey(seed)
        chosen = jax.random.choice(key, n, shape=(max_samples,), replace=False)
        X_sub = X[chosen]
    else:
        X_sub = X

    # Pairwise distances within the (sub)sample.
    D = euclidean_distance_matrix(X_sub, X_sub)

    # Strict upper triangle: each unordered pair once, diagonal excluded.
    pair_mask = jnp.triu(jnp.ones_like(D, dtype=bool), k=1)
    distances = D[pair_mask]

    return float(jnp.percentile(distances, percentile))
[docs]
@jit
def safe_divide(numerator: chex.Array, denominator: chex.Array, epsilon: float = 1e-10) -> chex.Array:
    """
    Safe division avoiding division by zero.

    Args:
        numerator: Numerator array
        denominator: Denominator array
        epsilon: Small value to add to denominator

    Returns:
        numerator / (denominator + epsilon)

    Example:
        >>> x = jnp.array([1.0, 2.0, 3.0])
        >>> y = jnp.array([2.0, 0.0, 1.0])
        >>> out = safe_divide(x, y)  # roughly [0.5, 2e+10, 3.0]; finite, not inf

    Notes:
        epsilon is added, not clamped: a denominator equal to ``-epsilon``
        still results in a division by zero.
    """
    # Shift the denominator first so a zero entry becomes epsilon.
    shifted_denominator = denominator + epsilon
    return numerator / shifted_denominator
# ============================================================================
# Module Information
# ============================================================================
# Aliases for convenience
# Each alias is bound to the very same function object as the matrix variant
# it names, so both spellings behave identically.
batch_squared_euclidean = squared_euclidean_distance_matrix
batch_euclidean = euclidean_distance_matrix

# Public API of this module: the names exported by a star-import.
# Grouped in the same order as the sections above.
__all__ = [
    # Matrix distance functions
    'euclidean_distance_matrix',
    'squared_euclidean_distance_matrix',
    'manhattan_distance_matrix',
    'lpnorm_distance_matrix',
    'omega_distance_matrix',
    'lomega_distance_matrix',
    'tangent_distance_matrix',
    # Wasserstein distance functions
    'wasserstein2_distance_matrix',
    'wasserstein2_omega_distance_matrix',
    'wasserstein2_relevance_distance_matrix',
    # Kernel functions
    'gaussian_kernel_matrix',
    'polynomial_kernel_matrix',
    # Pairwise functions
    'euclidean_distance',
    'squared_euclidean_distance',
    'manhattan_distance',
    'lpnorm_distance',
    'omega_distance',
    'lomega_distance',
    # Utilities
    'estimate_sigma',
    'safe_divide',
    # Aliases
    'batch_squared_euclidean',
    'batch_euclidean',
]