# Source code for prosemble.core.distance

"""
JAX-based distance functions for Prosemble.

This module provides GPU-accelerated, vectorized distance computations
using JAX. All functions are JIT-compiled for maximum performance.

Mathematical Background
-----------------------
Distance metrics are fundamental to prototype-based learning algorithms.
This implementation focuses on:
1. Batch/matrix operations (no Python loops)
2. GPU compatibility
3. JIT compilation for speed
4. Numerical stability

Author: Prosemble Contributors
License: MIT
"""

import jax
import jax.numpy as jnp
from jax import jit
from functools import partial
import chex


# ============================================================================
# Core Distance Functions (Pairwise Matrices)
# ============================================================================


@jit
def euclidean_distance_matrix(X: chex.Array, Y: chex.Array) -> chex.Array:
    r"""
    Compute pairwise Euclidean distances between rows of X and Y.

    Uses the expansion
    :math:`\|x - y\|^2 = \|x\|^2 + \|y\|^2 - 2 x^T y`, so the whole matrix
    comes from one matrix product instead of an (n, m, d) difference tensor.

    Formula:
        :math:`D_{ij} = \|X_i - Y_j\| = \sqrt{\sum_k (X_{ik} - Y_{jk})^2}`

    Args:
        X: Array of shape (n, d) - n samples with d features
        Y: Array of shape (m, d) - m samples with d features

    Returns:
        D: Array of shape (n, m) where :math:`D_{ij} = \|X_i - Y_j\|`

    Complexity:
        Time: O(nmd) - single matrix multiplication
        Space: O(nm) - output matrix

    Example:
        >>> X = jnp.array([[0, 0], [1, 1], [2, 2]])
        >>> Y = jnp.array([[0, 0], [3, 3]])
        >>> D = euclidean_distance_matrix(X, Y)
        >>> D.shape
        (3, 2)
        >>> D[0, 0]  # Distance from X[0] to Y[0]
        0.0
        >>> D[2, 1]  # Distance from X[2] to Y[1]
        1.414...

    Notes:
        - Round-off in the expansion can produce slightly negative squared
          distances; they are clamped to zero before the square root.
        - The square root is taken through a mask so that entries that are
          exactly zero contribute a zero gradient. A plain ``sqrt`` has an
          unbounded derivative at 0, which turns into NaN gradients as soon
          as a row of X coincides with a row of Y.
    """
    chex.assert_rank(X, 2)
    chex.assert_rank(Y, 2)
    chex.assert_equal(X.shape[1], Y.shape[1])

    # Squared row norms, shaped so they broadcast to (n, m)
    X_sq = jnp.sum(X ** 2, axis=1, keepdims=True)  # (n, 1)
    Y_sq = jnp.sum(Y ** 2, axis=1, keepdims=True).T  # (1, m)

    # Cross term
    XY = X @ Y.T  # (n, m)

    # Clamp: floating-point cancellation may leave tiny negative values
    D_sq = jnp.maximum(X_sq + Y_sq - 2 * XY, 0.0)

    # Gradient-safe sqrt ("double where"): the sqrt never sees a zero input,
    # so its derivative stays finite; masked entries are set back to 0.
    # NaN inputs compare unequal to 0 and therefore still propagate as NaN.
    is_zero = D_sq == 0
    safe_sq = jnp.where(is_zero, 1.0, D_sq)
    return jnp.where(is_zero, 0.0, jnp.sqrt(safe_sq))
@jit
def squared_euclidean_distance_matrix(X: chex.Array, Y: chex.Array) -> chex.Array:
    r"""
    Pairwise squared Euclidean distances between the rows of X and Y.

    Formula:
        :math:`D^2_{ij} = \|X_i - Y_j\|^2 = \sum_k (X_{ik} - Y_{jk})^2`

    Evaluated as :math:`\|X_i\|^2 + \|Y_j\|^2 - 2 X_i^T Y_j`, so no square
    root is ever taken.

    Args:
        X: Array of shape (n, d)
        Y: Array of shape (m, d)

    Returns:
        Array of shape (n, m) holding the squared distances, clamped to be
        non-negative.

    Example:
        >>> X = jnp.array([[0, 0], [1, 1]])
        >>> Y = jnp.array([[0, 0], [2, 2]])
        >>> D_sq = squared_euclidean_distance_matrix(X, Y)
        >>> D_sq[1, 1]  # Squared distance from [1,1] to [2,2]
        2.0
    """
    chex.assert_rank(X, 2)
    chex.assert_rank(Y, 2)
    chex.assert_equal(X.shape[1], Y.shape[1])

    row_norms = jnp.sum(X ** 2, axis=1, keepdims=True)  # (n, 1)
    col_norms = jnp.sum(Y ** 2, axis=1, keepdims=True).T  # (1, m)
    cross_terms = X @ Y.T  # (n, m)

    # Cancellation can leave tiny negative entries; clip them to zero.
    return jnp.maximum(row_norms + col_norms - 2 * cross_terms, 0.0)
@jit
def manhattan_distance_matrix(X: chex.Array, Y: chex.Array) -> chex.Array:
    r"""
    Pairwise Manhattan (L1) distances between the rows of X and Y.

    Formula:
        :math:`D_{ij} = \|X_i - Y_j\|_1 = \sum_k |X_{ik} - Y_{jk}|`

    Args:
        X: Array of shape (n, d)
        Y: Array of shape (m, d)

    Returns:
        Array of shape (n, m); entry [i, j] is the L1 distance between
        X[i] and Y[j].

    Complexity:
        Time: O(nmd)
        Space: O(nmd) - an (n, m, d) difference tensor is materialised

    Example:
        >>> X = jnp.array([[0, 0], [1, 1]])
        >>> Y = jnp.array([[0, 0], [2, 2]])
        >>> D = manhattan_distance_matrix(X, Y)
        >>> D[1, 1]  # Manhattan distance from [1,1] to [2,2]
        2.0
    """
    chex.assert_rank(X, 2)
    chex.assert_rank(Y, 2)
    chex.assert_equal(X.shape[1], Y.shape[1])

    # (n, 1, d) - (1, m, d) broadcasts to every (row of X, row of Y) pair
    abs_diff = jnp.abs(X[:, None, :] - Y[None, :, :])

    # Collapse the feature axis
    return jnp.sum(abs_diff, axis=2)
def lpnorm_distance_matrix(
    X: chex.Array,
    Y: chex.Array,
    p: float | int
) -> chex.Array:
    r"""
    Pairwise L-p norm distances between the rows of X and Y.

    Formula:
        :math:`D_{ij} = \|X_i - Y_j\|_p = (\sum_k |X_{ik} - Y_{jk}|^p)^{1/p}`

    Special Cases:
        p = 1: Manhattan distance
        p = 2: Euclidean distance
        p = :math:`\infty`: Chebyshev distance (max absolute difference)

    Args:
        X: Array of shape (n, d)
        Y: Array of shape (m, d)
        p: Order of the norm (intended for p >= 1; ``jnp.inf`` selects the
            Chebyshev branch)

    Returns:
        Array of shape (n, m); entry [i, j] is the L-p distance between
        X[i] and Y[j].

    Complexity:
        Time: O(nmd)
        Space: O(nmd)

    Example:
        >>> X = jnp.array([[0, 0], [1, 1]])
        >>> Y = jnp.array([[0, 0], [3, 4]])
        >>> D = lpnorm_distance_matrix(X, Y, p=2)        # Euclidean
        >>> D = lpnorm_distance_matrix(X, Y, p=1)        # Manhattan
        >>> D = lpnorm_distance_matrix(X, Y, p=jnp.inf)  # Chebyshev

    Notes:
        - ``p`` is inspected with a plain Python ``if``, so it must be a
          concrete value; this function is not wrapped in ``jit``.
        - The dedicated Manhattan/Euclidean functions cover p=1 and p=2.
    """
    chex.assert_rank(X, 2)
    chex.assert_rank(Y, 2)
    chex.assert_equal(X.shape[1], Y.shape[1])

    # All pairwise differences: (n, 1, d) - (1, m, d) -> (n, m, d)
    pairwise_diff = X[:, None, :] - Y[None, :, :]

    if p == jnp.inf:
        # Chebyshev: largest absolute coordinate difference
        return jnp.max(jnp.abs(pairwise_diff), axis=2)

    # General case: sum of p-th powers, then the p-th root
    power_sum = jnp.sum(jnp.power(jnp.abs(pairwise_diff), p), axis=2)
    return jnp.power(power_sum, 1.0 / p)
@jit
def omega_distance_matrix(
    X: chex.Array,
    Y: chex.Array,
    omega: chex.Array
) -> chex.Array:
    r"""
    Pairwise squared distances after projecting both inputs with ``omega``.

    Formula:
        :math:`D_{ij} = \|X_i \Omega - Y_j \Omega\|^2`

    Args:
        X: Array of shape (n, d)
        Y: Array of shape (m, d)
        omega: Projection matrix of shape (d, k)

    Returns:
        Array of shape (n, m) with squared Euclidean distances measured in
        the k-dimensional projected space.

    Example:
        >>> X = jnp.array([[1, 2, 3], [4, 5, 6]])
        >>> Y = jnp.array([[0, 0, 0], [1, 1, 1]])
        >>> omega = jnp.array([[1, 0], [0, 1], [0, 0]])  # Keep first 2 dims
        >>> D = omega_distance_matrix(X, Y, omega)
        >>> D.shape
        (2, 2)

    Notes:
        - With an identity ``omega`` this is the plain squared Euclidean
          distance matrix.
    """
    chex.assert_rank(X, 2)
    chex.assert_rank(Y, 2)
    chex.assert_rank(omega, 2)
    chex.assert_equal(X.shape[1], omega.shape[0])
    chex.assert_equal(Y.shape[1], omega.shape[0])

    # Map both point sets into the projected space, (n, k) and (m, k),
    # then reuse the squared-Euclidean routine there.
    return squared_euclidean_distance_matrix(X @ omega, Y @ omega)
@jit
def lomega_distance_matrix(
    X: chex.Array,
    Y: chex.Array,
    omegas: chex.Array
) -> chex.Array:
    r"""
    Compute distances using one projection matrix per prototype (Local Omega).

    Formula:
        :math:`D_{ij} = \|X_i \Omega_j - Y_j \Omega_j\|^2`

    where :math:`\Omega_j` is the projection matrix attached to prototype
    :math:`Y_j`. Each column of the result uses a different projection;
    nothing is summed across projections.

    Args:
        X: Array of shape (n, d) - data points
        Y: Array of shape (m, d) - prototypes/centroids
        omegas: Array of shape (m, d, k) - m projection matrices of size
            (d, k); Y[j] is paired with omegas[j]

    Returns:
        D: Array of shape (n, m) with squared distances, column j measured
        in the space defined by omegas[j]

    Complexity:
        Time: O(nmdk)
        Space: O(nmk)

    Example:
        >>> n, m, d, k = 10, 3, 5, 2
        >>> X = jax.random.normal(jax.random.PRNGKey(0), (n, d))
        >>> Y = jax.random.normal(jax.random.PRNGKey(1), (m, d))
        >>> omegas = jax.random.normal(jax.random.PRNGKey(2), (m, d, k))
        >>> D = lomega_distance_matrix(X, Y, omegas)
        >>> D.shape
        (10, 3)

    Notes:
        - Generalizes omega_distance_matrix to per-prototype projections.
    """
    chex.assert_rank(X, 2)
    chex.assert_rank(Y, 2)
    chex.assert_rank(omegas, 3)
    chex.assert_equal(X.shape[1], omegas.shape[1])
    chex.assert_equal(Y.shape[1], omegas.shape[1])
    chex.assert_equal(Y.shape[0], omegas.shape[0])

    # Every data point through every projection: X[i] @ omegas[j] -> (n, m, k)
    X_proj = jnp.einsum('nd,mdk->nmk', X, omegas)

    # Each prototype only through its own projection: Y[j] @ omegas[j] -> (m, k)
    Y_proj = jnp.einsum('md,mdk->mk', Y, omegas)

    # Broadcast (n, m, k) - (1, m, k), then reduce the projected axis
    residual = X_proj - Y_proj[None, :, :]
    return jnp.sum(residual ** 2, axis=2)
@jit
def tangent_distance_matrix(
    X: chex.Array,
    Y: chex.Array,
    omegas: chex.Array
) -> chex.Array:
    """
    Pairwise localized (per-prototype) squared tangent distances.

    For prototype j with subspace basis Omega_j of shape (d, s), the
    component of ``x - w_j`` lying in the span of Omega_j is removed and the
    squared norm of what remains is returned:

        d(x, w_j) = ||(I - Omega_j @ Omega_j^T)(x - w_j)||^2

    Step by step:
        diff   = x - w_j
        coeffs = Omega_j^T @ diff      (coordinates in the subspace)
        recon  = Omega_j @ coeffs      (back in the ambient space)
        d      = ||diff - recon||^2

    Args:
        X: Array of shape (n, d) - data points
        Y: Array of shape (m, d) - prototypes
        omegas: Array of shape (m, d, s) - one subspace basis per prototype.
            The formula is a true orthogonal projection only when the
            columns are orthonormal; that is not checked here.

    Returns:
        Array of shape (n, m) with squared tangent distances.

    Notes:
        Based on Saralajew, S., & Villmann, T. (2016). Adaptive tangent
        distances in generalized learning vector quantization.
    """
    chex.assert_rank(X, 2)
    chex.assert_rank(Y, 2)
    chex.assert_rank(omegas, 3)
    chex.assert_equal(X.shape[1], Y.shape[1])
    chex.assert_equal(Y.shape[0], omegas.shape[0])
    chex.assert_equal(Y.shape[1], omegas.shape[1])

    # All sample/prototype differences: (n, m, d)
    deltas = X[:, None, :] - Y[None, :, :]

    # Coordinates of each difference in its prototype's subspace: (n, m, s)
    subspace_coords = jnp.einsum('nmd,mds->nms', deltas, omegas)

    # Map those coordinates back to the ambient space: (n, m, d).
    # Contracting over s with the same (m, d, s) tensor applies Omega_j.
    in_subspace = jnp.einsum('nms,mds->nmd', subspace_coords, omegas)

    # What is left after removing the subspace component
    residual = deltas - in_subspace

    return jnp.sum(residual ** 2, axis=2)
# ============================================================================
# Kernel Functions
# ============================================================================
@jit
def gaussian_kernel_matrix(
    X: chex.Array,
    Y: chex.Array,
    sigma: float
) -> chex.Array:
    r"""
    Gaussian (RBF) kernel matrix between the rows of X and Y.

    Formula:
        :math:`K_{ij} = \exp(-\|X_i - Y_j\|^2 / (2\sigma^2))`

    Args:
        X: Array of shape (n, d)
        Y: Array of shape (m, d)
        sigma: Bandwidth parameter (:math:`\sigma > 0`). This is the
            bandwidth, not the variance (the variance is :math:`\sigma^2`).
            No check is made here; ``sigma == 0`` divides by zero.

    Returns:
        Array of shape (n, m). Entries equal 1 where the squared distance
        is 0 and decay towards 0 as the distance grows.

    Complexity:
        Time: O(nmd)
        Space: O(nm)

    Example:
        >>> X = jnp.array([[0, 0], [1, 1]])
        >>> Y = jnp.array([[0, 0], [2, 2]])
        >>> K = gaussian_kernel_matrix(X, Y, sigma=1.0)
        >>> K[0, 0]  # Identical points
        1.0
        >>> K[0, 1] < K[0, 0]  # Decreases with distance
        True

    Notes:
        - Small :math:`\sigma` makes the kernel sharply peaked; large
          :math:`\sigma` makes it flat.
        - The squared distances come from
          ``squared_euclidean_distance_matrix``, which clamps them to be
          non-negative, so the exponent is never positive.
    """
    chex.assert_rank(X, 2)
    chex.assert_rank(Y, 2)
    chex.assert_equal(X.shape[1], Y.shape[1])

    sq_dists = squared_euclidean_distance_matrix(X, Y)

    # exp(-d^2 / (2 sigma^2))
    return jnp.exp(-sq_dists / (2 * sigma ** 2))
@jit
def polynomial_kernel_matrix(
    X: chex.Array,
    Y: chex.Array,
    degree: int = 3,
    coef0: float = 1.0
) -> chex.Array:
    r"""
    Polynomial kernel matrix between the rows of X and Y.

    Formula:
        :math:`K_{ij} = (X_i^T Y_j + c)^d`

    where *d* is ``degree`` and *c* is ``coef0``.

    Args:
        X: Array of shape (n, d)
        Y: Array of shape (m, d)
        degree: Polynomial degree (:math:`d \ge 1`)
        coef0: Additive constant (:math:`c \ge 0`)

    Returns:
        Array of shape (n, m) with polynomial kernel values.

    Example:
        >>> X = jnp.array([[1, 2], [3, 4]])
        >>> Y = jnp.array([[1, 0], [0, 1]])
        >>> K = polynomial_kernel_matrix(X, Y, degree=2, coef0=1.0)

    Notes:
        - degree=1, coef0=0 gives the linear kernel (plain dot products).
    """
    chex.assert_rank(X, 2)
    chex.assert_rank(Y, 2)
    chex.assert_equal(X.shape[1], Y.shape[1])

    # Gram matrix of inner products, shifted and raised to the given degree
    gram = X @ Y.T
    shifted = gram + coef0
    return jnp.power(shifted, degree)
# ============================================================================
# Pairwise Distance Functions (for single pairs)
# ============================================================================
@jit
def euclidean_distance(x: chex.Array, y: chex.Array) -> chex.Array:
    """
    Euclidean distance between two vectors.

    Args:
        x: Array of shape (d,)
        y: Array of shape (d,)

    Returns:
        Scalar distance

    Example:
        >>> x = jnp.array([0, 0, 0])
        >>> y = jnp.array([1, 1, 1])
        >>> d = euclidean_distance(x, y)
        >>> d
        Array(1.732..., dtype=float32)

    Notes:
        The square root is applied through a mask so that the gradient at
        ``x == y`` is zero. A plain ``sqrt`` has an unbounded derivative at
        0, which yields NaN gradients when the two vectors coincide.
    """
    chex.assert_equal_shape([x, y])

    sq_dist = jnp.sum((x - y) ** 2)

    # Gradient-safe sqrt ("double where"): the sqrt never receives 0, and
    # the masked result is reset to 0. NaN inputs still propagate because
    # NaN == 0 is False.
    is_zero = sq_dist == 0
    safe_sq = jnp.where(is_zero, 1.0, sq_dist)
    return jnp.where(is_zero, 0.0, jnp.sqrt(safe_sq))
@jit
def squared_euclidean_distance(x: chex.Array, y: chex.Array) -> chex.Array:
    """
    Squared Euclidean distance between two equally-shaped vectors.

    Args:
        x: Array of shape (d,)
        y: Array of shape (d,)

    Returns:
        Scalar: the sum of squared element-wise differences.
    """
    chex.assert_equal_shape([x, y])
    delta = x - y
    return jnp.sum(delta ** 2)
@jit
def manhattan_distance(x: chex.Array, y: chex.Array) -> chex.Array:
    """
    Manhattan (L1) distance between two equally-shaped vectors.

    Args:
        x: Array of shape (d,)
        y: Array of shape (d,)

    Returns:
        Scalar: the sum of absolute element-wise differences.
    """
    chex.assert_equal_shape([x, y])
    abs_delta = jnp.abs(x - y)
    return jnp.sum(abs_delta)
def lpnorm_distance(x: chex.Array, y: chex.Array, p: float = 2) -> chex.Array:
    """
    Lp-norm distance between two equally-shaped vectors.

    Delegates to ``jnp.linalg.norm`` on the difference ``x - y``.

    Args:
        x: Array of shape (d,)
        y: Array of shape (d,)
        p: Order of the norm, forwarded as ``ord`` (supports inf)

    Returns:
        Scalar distance
    """
    chex.assert_equal_shape([x, y])
    delta = x - y
    return jnp.linalg.norm(delta, ord=p)
@jit
def omega_distance(x: chex.Array, y: chex.Array, omega: chex.Array) -> chex.Array:
    """
    Omega (projection-based) squared distance between two vectors.

    The difference ``x - y`` is multiplied by ``omega`` and the squared
    norm of the result is returned: ||(x - y) @ omega||^2.

    Args:
        x: Array of shape (d,)
        y: Array of shape (d,)
        omega: Projection matrix of shape (d, k)

    Returns:
        Scalar squared distance in projected space
    """
    chex.assert_equal_shape([x, y])
    projected_diff = (x - y) @ omega
    return jnp.sum(projected_diff ** 2)
def lomega_distance(X: chex.Array, Y: chex.Array, omegas: chex.Array) -> chex.Array:
    """
    Local omega distance with per-prototype projection matrices.

    Entry [i, j] of the result is ||(X[i] - Y[j]) @ omegas[j]||^2, i.e. the
    j-th prototype is always paired with the j-th projection matrix.

    Args:
        X: Array of shape (n, d)
        Y: Array of shape (m, d) - prototypes
        omegas: Array of shape (m, d, k) - one projection matrix per
            prototype

    Returns:
        Distance matrix of shape (n, m)
    """
    def _pair_distance(point, prototype, projection):
        # Squared norm of the projected difference for one (point, prototype)
        mapped = (point - prototype) @ projection
        return jnp.sum(mapped ** 2)

    # Inner map: fixed point, walk prototypes and projections in lockstep.
    across_prototypes = jax.vmap(_pair_distance, in_axes=(None, 0, 0))

    # Outer map: walk the data points, holding prototypes/projections fixed.
    across_points = jax.vmap(across_prototypes, in_axes=(0, None, None))

    return across_points(X, Y, omegas)
# ============================================================================
# Utility Functions
# ============================================================================
def estimate_sigma(
    X: chex.Array,
    percentile: float = 50.0,
    *,
    max_samples: int = 1000,
    seed: int = 0,
) -> float:
    """
    Estimate sigma for the Gaussian kernel from pairwise distances.

    Strategy: take the requested percentile (the median by default) of the
    Euclidean distances between all distinct pairs of rows in X.

    Args:
        X: Data array of shape (n, d) with n >= 2
        percentile: Percentile of distances to use (0-100)
        max_samples: Largest number of rows used for the pairwise distance
            computation. If X has more rows, ``max_samples`` of them are
            drawn without replacement.
        seed: Seed for the PRNG key used when subsampling. The default
            keeps the subsample (and therefore the result) deterministic.

    Returns:
        sigma: The chosen percentile of the pairwise distances, as a
        Python float.

    Raises:
        ValueError: If X has fewer than two rows or ``max_samples`` is
            smaller than two (no pair of points exists to measure).

    Example:
        >>> X = jax.random.normal(jax.random.PRNGKey(0), (100, 10))
        >>> sigma = estimate_sigma(X, percentile=50)

    Notes:
        - The returned value is the raw percentile. Any further scaling
          (for instance dividing by ``sqrt(2 * n_clusters)``) is left to
          the caller.
        - Subsampling bounds the O(n^2) cost of the distance matrix.
    """
    if max_samples < 2:
        raise ValueError("max_samples must be at least 2")

    n = X.shape[0]
    if n < 2:
        raise ValueError(
            "estimate_sigma needs at least two samples to form a pairwise distance"
        )

    # For large datasets, work on a random subset of the rows
    if n > max_samples:
        key = jax.random.PRNGKey(seed)
        indices = jax.random.choice(key, n, shape=(max_samples,), replace=False)
        X_sub = X[indices]
    else:
        X_sub = X

    # Compute pairwise distances
    D = euclidean_distance_matrix(X_sub, X_sub)

    # Strict upper triangle: each unordered pair once, no self-distances
    mask = jnp.triu(jnp.ones_like(D, dtype=bool), k=1)
    distances = D[mask]

    return float(jnp.percentile(distances, percentile))
@jit
def safe_divide(numerator: chex.Array,
                denominator: chex.Array,
                epsilon: float = 1e-10) -> chex.Array:
    """
    Divide after shifting the denominator by a small constant.

    Args:
        numerator: Numerator array
        denominator: Denominator array
        epsilon: Small value added to the denominator

    Returns:
        numerator / (denominator + epsilon)

    Example:
        >>> x = jnp.array([1.0, 2.0, 3.0])
        >>> y = jnp.array([2.0, 0.0, 1.0])
        >>> out = safe_divide(x, y)  # approx. [0.5, 2e+10, 3.0]
        >>> bool(jnp.isfinite(out).all())  # zero denominator gives no inf
        True

    Notes:
        The shift is one-sided: a denominator equal to ``-epsilon`` still
        results in a division by zero.
    """
    shifted_denominator = denominator + epsilon
    return numerator / shifted_denominator
# ============================================================================ # Module Information # ============================================================================ # Aliases for convenience batch_squared_euclidean = squared_euclidean_distance_matrix batch_euclidean = euclidean_distance_matrix __all__ = [ # Matrix distance functions 'euclidean_distance_matrix', 'squared_euclidean_distance_matrix', 'manhattan_distance_matrix', 'lpnorm_distance_matrix', 'omega_distance_matrix', 'lomega_distance_matrix', 'tangent_distance_matrix', # Kernel functions 'gaussian_kernel_matrix', 'polynomial_kernel_matrix', # Pairwise functions 'euclidean_distance', 'squared_euclidean_distance', 'manhattan_distance', 'lpnorm_distance', 'omega_distance', 'lomega_distance', # Utilities 'estimate_sigma', 'safe_divide', # Aliases 'batch_squared_euclidean', 'batch_euclidean', ]