Source code for prosemble.core.distance

"""
JAX-based distance functions for Prosemble.

This module provides GPU-accelerated, vectorized distance computations
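using JAX. Most functions are JIT-compiled; those that branch on Python-level arguments (e.g. ``lpnorm_distance_matrix`` on ``p``) or subsample data (``estimate_sigma``) run eagerly.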

Mathematical Background
-----------------------
Distance metrics are fundamental to prototype-based learning algorithms.
This implementation focuses on:
1. Batch/matrix operations (no Python loops)
2. GPU compatibility
3. JIT compilation for speed
4. Numerical stability

Author: Prosemble Contributors
License: MIT
"""

import jax
import jax.numpy as jnp
from jax import jit
from functools import partial
import chex


# ============================================================================
# Core Distance Functions (Pairwise Matrices)
# ============================================================================


@jit
def euclidean_distance_matrix(X: chex.Array, Y: chex.Array) -> chex.Array:
    r"""
    Compute pairwise Euclidean distances between rows of X and Y.

    Uses the expansion trick: :math:`\|x - y\|^2 = \|x\|^2 + \|y\|^2 - 2 x^T y`.

    Formula:
        :math:`D_{ij} = \|X_i - Y_j\| = \sqrt{\sum_k (X_{ik} - Y_{jk})^2}`

    Args:
        X: Array of shape (n, d) - n samples with d features
        Y: Array of shape (m, d) - m samples with d features

    Returns:
        D: Array of shape (n, m) where :math:`D_{ij} = \|X_i - Y_j\|`

    Complexity:
        Time: O(nmd) - dominated by a single matrix multiplication
        Space: O(nm) - output matrix

    Example:
        >>> X = jnp.array([[0, 0], [1, 1], [2, 2]])
        >>> Y = jnp.array([[0, 0], [3, 3]])
        >>> D = euclidean_distance_matrix(X, Y)
        >>> D.shape
        (3, 2)
        >>> float(D[0, 0])  # Distance from X[0] to Y[0]
        0.0
        >>> float(D[2, 1])  # Distance from X[2] to Y[1]
        1.414...

    Notes:
        - Round-off in the expansion trick can make a squared distance slightly
          negative; it is clamped at 0 before taking the square root.
        - Gradient-safe: where the squared distance is exactly 0 (e.g. a
          prototype sitting on a data point) the square root is evaluated on a
          dummy value and masked out, so ``jax.grad`` yields 0 there instead of
          NaN. Forward values are unaffected.
        - JIT-compiled: the first call per input shape compiles, later calls
          reuse the compiled computation.
    """
    chex.assert_rank(X, 2)
    chex.assert_rank(Y, 2)
    chex.assert_equal(X.shape[1], Y.shape[1])

    X_sq = jnp.sum(X ** 2, axis=1, keepdims=True)  # (n, 1)
    Y_sq = jnp.sum(Y ** 2, axis=1, keepdims=True).T  # (1, m)
    XY = X @ Y.T  # (n, m)

    # Clamp round-off negatives before the square root.
    D_sq = jnp.maximum(X_sq + Y_sq - 2 * XY, 0.0)

    # d/dx sqrt(x) is infinite at x = 0; multiplied by a zero upstream
    # derivative it becomes NaN. Take the root of a dummy 1.0 at those entries
    # and mask them back to 0 ("double where"). Testing ``== 0`` rather than
    # ``> 0`` keeps NaN inputs propagating as NaN instead of being masked.
    is_zero = D_sq == 0.0
    safe_sq = jnp.where(is_zero, 1.0, D_sq)
    return jnp.where(is_zero, 0.0, jnp.sqrt(safe_sq))
@jit
def squared_euclidean_distance_matrix(X: chex.Array, Y: chex.Array) -> chex.Array:
    r"""
    Compute pairwise squared Euclidean distances.

    More efficient than ``euclidean_distance_matrix(X, Y)**2`` because it
    avoids the square root entirely.

    Formula:
        :math:`D^2_{ij} = \|X_i - Y_j\|^2 = \sum_k (X_{ik} - Y_{jk})^2`

    Args:
        X: Array of shape (n, d)
        Y: Array of shape (m, d)

    Returns:
        D: Array of shape (n, m) where :math:`D^2_{ij} = \|X_i - Y_j\|^2`

    Complexity:
        Time: O(nmd)
        Space: O(nm)

    Example:
        >>> X = jnp.array([[0, 0], [1, 1]])
        >>> Y = jnp.array([[0, 0], [2, 2]])
        >>> D_sq = squared_euclidean_distance_matrix(X, Y)
        >>> float(D_sq[1, 1])  # Squared distance from [1,1] to [2,2]
        2.0

    Notes:
        - Preferred over euclidean when squared distances are sufficient.
        - Many algorithms (FCM, PCM) use squared distances directly.
        - Uses the expansion :math:`\|x\|^2 + \|y\|^2 - 2 x^T y`, which can lose
          precision through cancellation for nearby points with large norms;
          any resulting small negatives are clamped to 0.
    """
    chex.assert_rank(X, 2)
    chex.assert_rank(Y, 2)
    chex.assert_equal(X.shape[1], Y.shape[1])

    X_sq = jnp.sum(X ** 2, axis=1, keepdims=True)  # (n, 1)
    Y_sq = jnp.sum(Y ** 2, axis=1, keepdims=True).T  # (1, m)
    XY = X @ Y.T  # (n, m)
    D_sq = X_sq + Y_sq - 2 * XY
    # Round-off in the expansion can dip slightly below zero.
    return jnp.maximum(D_sq, 0.0)
@jit
def manhattan_distance_matrix(X: chex.Array, Y: chex.Array) -> chex.Array:
    r"""
    Compute pairwise Manhattan (L1) distances.

    Formula:
        :math:`D_{ij} = \|X_i - Y_j\|_1 = \sum_k |X_{ik} - Y_{jk}|`

    Args:
        X: Array of shape (n, d)
        Y: Array of shape (m, d)

    Returns:
        D: Array of shape (n, m) where D[i, j] is the Manhattan distance

    Complexity:
        Time: O(nmd)
        Space: O(nmd) - intermediate broadcasting

    Example:
        >>> X = jnp.array([[0, 0], [1, 1]])
        >>> Y = jnp.array([[0, 0], [2, 2]])
        >>> D = manhattan_distance_matrix(X, Y)
        >>> int(D[1, 1])  # Manhattan distance from [1,1] to [2,2]
        2

    Implementation:
        Broadcasting ``X[:, None, :] - Y[None, :, :]`` creates an (n, m, d)
        tensor; absolute differences are then summed over the feature axis.

    Notes:
        - Also known as "taxicab" or "city block" distance
        - More robust to outliers than Euclidean distance
        - Natural for sparse/binary features
    """
    chex.assert_rank(X, 2)
    chex.assert_rank(Y, 2)
    chex.assert_equal(X.shape[1], Y.shape[1])

    # Broadcasting: (n, 1, d) - (1, m, d) = (n, m, d)
    diff = X[:, None, :] - Y[None, :, :]
    # Sum absolute differences along the feature dimension
    D = jnp.sum(jnp.abs(diff), axis=2)
    return D
def lpnorm_distance_matrix(
    X: chex.Array, Y: chex.Array, p: float | int
) -> chex.Array:
    r"""
    Compute pairwise L-p norm distances.

    Formula:
        :math:`D_{ij} = \|X_i - Y_j\|_p = (\sum_k |X_{ik} - Y_{jk}|^p)^{1/p}`

    Special Cases:
        p = 1: Manhattan distance
        p = 2: Euclidean distance
        p = :math:`\infty`: Chebyshev distance (max absolute difference)

    Args:
        X: Array of shape (n, d)
        Y: Array of shape (m, d)
        p: Order of the norm (p >= 1). Compared against ``jnp.inf`` in Python,
            which is why this function is not wrapped in ``jit``.

    Returns:
        D: Array of shape (n, m) where D[i, j] is the L-p distance

    Complexity:
        Time: O(nmd)
        Space: O(nmd)

    Example:
        >>> X = jnp.array([[0, 0], [1, 1]])
        >>> Y = jnp.array([[0, 0], [3, 4]])
        >>> D = lpnorm_distance_matrix(X, Y, p=2)        # Euclidean
        >>> D = lpnorm_distance_matrix(X, Y, p=1)        # Manhattan
        >>> D = lpnorm_distance_matrix(X, Y, p=jnp.inf)  # Chebyshev

    Notes:
        - For p=1, use manhattan_distance_matrix for better performance
        - For p=2, use euclidean_distance_matrix for better performance
        - For p=inf, computes ``max(|x - y|)``
        - Gradient-safe at zero distance: the outer ``1/p`` root is evaluated on
          a dummy value where the inner sum is exactly 0, so ``jax.grad`` does
          not produce NaN for coincident points (forward values unchanged).
    """
    chex.assert_rank(X, 2)
    chex.assert_rank(Y, 2)
    chex.assert_equal(X.shape[1], Y.shape[1])

    # Broadcasting: (n, 1, d) - (1, m, d) = (n, m, d)
    abs_diff = jnp.abs(X[:, None, :] - Y[None, :, :])

    if p == jnp.inf:
        # Chebyshev distance: max absolute difference
        return jnp.max(abs_diff, axis=2)

    # General L-p norm. The derivative of s**(1/p) is infinite at s = 0, so use
    # the "double where" trick there; ``== 0`` keeps NaN sums propagating.
    S = jnp.sum(jnp.power(abs_diff, p), axis=2)
    is_zero = S == 0
    safe_S = jnp.where(is_zero, 1.0, S)
    return jnp.where(is_zero, 0.0, jnp.power(safe_S, 1.0 / p))
@jit
def omega_distance_matrix(
    X: chex.Array, Y: chex.Array, omega: chex.Array
) -> chex.Array:
    r"""
    Compute squared distances in a projected space defined by matrix Omega.

    Formula:
        :math:`D_{ij} = \|X_i \Omega - Y_j \Omega\|^2`

    where :math:`\Omega` is a projection matrix that transforms the feature space.

    Args:
        X: Array of shape (n, d)
        Y: Array of shape (m, d)
        omega: Projection matrix of shape (d, k) where k is the projection dimension

    Returns:
        D: Array of shape (n, m) with squared distances in projected space

    Complexity:
        Time: O(ndk + mdk + nmk) = O((n+m)dk + nmk)
        Space: O(nk + mk + nm)

    Use Cases:
        - Dimensionality reduction for distance computation
        - Learning relevance of features (omega learned from data)
        - Mahalanobis-like distances (when :math:`\Omega = L` where :math:`\Sigma = LL^T`)

    Example:
        >>> X = jnp.array([[1, 2, 3], [4, 5, 6]])
        >>> Y = jnp.array([[0, 0, 0], [1, 1, 1]])
        >>> omega = jnp.array([[1, 0], [0, 1], [0, 0]])  # Project to first 2 dims
        >>> D = omega_distance_matrix(X, Y, omega)
        >>> D.shape
        (2, 2)

    Notes:
        - When omega is the identity, reduces to squared Euclidean distance
        - When omega is learned, enables adaptive distance metrics
        - Used in GLVQ (Generalized Learning Vector Quantization)
    """
    chex.assert_rank(X, 2)
    chex.assert_rank(Y, 2)
    chex.assert_rank(omega, 2)
    chex.assert_equal(X.shape[1], omega.shape[0])
    chex.assert_equal(Y.shape[1], omega.shape[0])

    # Project data to the new space
    X_proj = X @ omega  # (n, k)
    Y_proj = Y @ omega  # (m, k)

    # Squared Euclidean distances in the projected space
    D_sq = squared_euclidean_distance_matrix(X_proj, Y_proj)
    return D_sq
@jit
def lomega_distance_matrix(
    X: chex.Array, Y: chex.Array, omegas: chex.Array
) -> chex.Array:
    r"""
    Compute distances using one projection matrix per prototype (local Omega).

    Formula:
        :math:`D_{ij} = \|X_i \Omega_j - Y_j \Omega_j\|^2 = \|(X_i - Y_j)\,\Omega_j\|^2`

    where :math:`\Omega_j` is the projection matrix attached to prototype
    ``Y[j]``. Only that prototype's own matrix enters :math:`D_{ij}`; there is
    no sum over the other matrices.

    Args:
        X: Array of shape (n, d) - data points
        Y: Array of shape (m, d) - prototypes/centroids
        omegas: Array of shape (m, d, k) - m projection matrices of size (d, k);
            ``Y[j]`` uses ``omegas[j]``

    Returns:
        D: Array of shape (n, m) with squared projected distances

    Complexity:
        Time: O(nmdk)
        Space: O(nmk)

    Use Cases:
        - Local relevance learning (each prototype has its own metric)
        - Localized variants of GMLVQ
        - Cluster-specific feature weighting

    Example:
        >>> n, m, d, k = 10, 3, 5, 2
        >>> X = jax.random.normal(jax.random.PRNGKey(0), (n, d))
        >>> Y = jax.random.normal(jax.random.PRNGKey(1), (m, d))
        >>> omegas = jax.random.normal(jax.random.PRNGKey(2), (m, d, k))
        >>> D = lomega_distance_matrix(X, Y, omegas)
        >>> D.shape
        (10, 3)

    Implementation:
        1. Project every X[i] through every omegas[j] with one einsum -> (n, m, k)
        2. Project each Y[j] through its own omegas[j] -> (m, k)
        3. Broadcast-subtract, square, and sum over the projection axis

    Notes:
        - Generalizes omega_distance_matrix to local (per-prototype) metrics
        - More flexible but computationally more expensive
        - Enables learning which features matter for each prototype
    """
    chex.assert_rank(X, 2)
    chex.assert_rank(Y, 2)
    chex.assert_rank(omegas, 3)
    chex.assert_equal(X.shape[1], omegas.shape[1])
    chex.assert_equal(Y.shape[1], omegas.shape[1])
    chex.assert_equal(Y.shape[0], omegas.shape[0])

    # X[i] @ omegas[j] for every pair (i, j): (n, d) x (m, d, k) -> (n, m, k)
    X_proj = jnp.einsum('nd,mdk->nmk', X, omegas)

    # Y[j] @ omegas[j] for every j (each prototype uses only its own matrix)
    Y_proj = jnp.einsum('md,mdk->mk', Y, omegas)  # (m, k)

    # (n, m, k) - (1, m, k), then sum over the projection dimension
    diff_sq = (X_proj - Y_proj[None, :, :]) ** 2
    return jnp.sum(diff_sq, axis=2)
@jit
def tangent_distance_matrix(
    X: chex.Array, Y: chex.Array, omegas: chex.Array
) -> chex.Array:
    r"""
    Compute pairwise localized tangent distances.

    Each prototype j has a subspace basis :math:`\Omega_j` of shape (d, s).
    The tangent distance removes the component of the difference vector that
    lies in that subspace:

    .. math::

        d(x, w_j) = \|(I - \Omega_j \Omega_j^T)(x - w_j)\|^2

    Parameters
    ----------
    X : array of shape (n, d)
        Data points.
    Y : array of shape (m, d)
        Prototypes.
    omegas : array of shape (m, d, s)
        Subspace bases per prototype, where s is the subspace dimension.

    Returns
    -------
    D : array of shape (n, m)
        Squared tangent distances.

    Notes
    -----
    The residual ``diff - Omega_j Omega_j^T diff`` is the orthogonal-complement
    projection only when each ``omegas[j]`` has orthonormal columns. This
    function does not orthonormalize them; callers must maintain that
    invariant.

    Based on Saralajew, S., & Villmann, T. (2016). Adaptive tangent distances
    in generalized learning vector quantization.
    """
    chex.assert_rank(X, 2)
    chex.assert_rank(Y, 2)
    chex.assert_rank(omegas, 3)
    chex.assert_equal(X.shape[1], Y.shape[1])
    chex.assert_equal(Y.shape[0], omegas.shape[0])
    chex.assert_equal(Y.shape[1], omegas.shape[1])

    # diff: (n, m, d)
    diff = X[:, None, :] - Y[None, :, :]

    # Coordinates in each prototype's subspace: (n, m, d) x (m, d, s) -> (n, m, s)
    proj = jnp.einsum('nmd,mds->nms', diff, omegas)

    # Map back to feature space (uses omegas[j]^T implicitly): -> (n, m, d)
    recon = jnp.einsum('nms,mds->nmd', proj, omegas)

    # Residual outside the subspace
    tangent_diff = diff - recon

    # Squared norm over features
    return jnp.sum(tangent_diff ** 2, axis=2)
# ============================================================================ # Wasserstein Distance Functions # ============================================================================
@jit
def wasserstein2_distance_matrix(
    X: chex.Array, means: chex.Array, log_variances: chex.Array
) -> chex.Array:
    r"""
    Squared 2-Wasserstein distances from points to diagonal-Gaussian prototypes.

    Prototype k is :math:`\mathcal{N}(\mu_k, \text{diag}(\sigma_k^2))` and each
    input row :math:`x` is treated as the point mass :math:`\delta_x`. For that
    pair the squared 2-Wasserstein distance has the closed form

    .. math::

        W_2^2(\delta_x, \mathcal{N}(\mu_k, \text{diag}(\sigma_k^2)))
            = \sum_j (x_j - \mu_{kj})^2 + \sum_j \sigma_{kj}^2

    i.e. the squared Euclidean distance to the mean plus the trace of the
    covariance. A prototype with less total variance therefore sits "closer"
    to every point than an otherwise identical, more spread-out one.

    Parameters
    ----------
    X : array of shape (n, d)
        Data points.
    means : array of shape (p, d)
        Prototype mean vectors.
    log_variances : array of shape (p, d)
        Per-dimension log-variances; exponentiated so variances stay positive.

    Returns
    -------
    D : array of shape (n, p)
        Squared 2-Wasserstein distances.

    References
    ----------
    .. [1] Villani, C. (2009). Optimal Transport: Old and New. Springer. Chapter 2.
    .. [2] Givens, C. R. & Shortt, R. M. (1984). A class of Wasserstein metrics
           for probability distributions. Michigan Math. J., 31(2).
    """
    chex.assert_rank(X, 2)
    chex.assert_rank(means, 2)
    chex.assert_rank(log_variances, 2)
    chex.assert_equal(X.shape[1], means.shape[1])
    chex.assert_equal(means.shape, log_variances.shape)

    # Trace of each prototype's covariance: (p,)
    total_variance = jnp.exp(log_variances).sum(axis=1)

    # Mean term: squared Euclidean distance to every prototype mean, (n, p)
    mean_term = squared_euclidean_distance_matrix(X, means)

    # Same variance offset for every data row
    return mean_term + jnp.expand_dims(total_variance, 0)
@jit
def wasserstein2_omega_distance_matrix(
    X: chex.Array, means: chex.Array, log_variances: chex.Array, omega: chex.Array
) -> chex.Array:
    r"""
    Squared 2-Wasserstein distance with a global metric (projection) adaptation.

    The mean term is measured after projecting both data and means through
    :math:`\Omega`; the variance term is added unprojected:

    .. math::

        W_2^2(x, k) = \|(x - \mu_k)\,\Omega\|^2 + \sum_j \sigma_{kj}^2

    Parameters
    ----------
    X : array of shape (n, d)
        Data points.
    means : array of shape (p, d)
        Prototype mean vectors.
    log_variances : array of shape (p, d)
        Log of prototype variances.
    omega : array of shape (d, l)
        Global projection matrix shared by all prototypes.

    Returns
    -------
    D : array of shape (n, p)
        Squared 2-Wasserstein distances in the projected space.
    """
    chex.assert_rank(X, 2)
    chex.assert_rank(means, 2)
    chex.assert_rank(log_variances, 2)
    chex.assert_rank(omega, 2)
    chex.assert_equal(X.shape[1], means.shape[1])
    chex.assert_equal(means.shape, log_variances.shape)
    chex.assert_equal(X.shape[1], omega.shape[0])

    projected_data = jnp.matmul(X, omega)  # (n, l)
    projected_means = jnp.matmul(means, omega)  # (p, l)
    total_variance = jnp.exp(log_variances).sum(axis=1)  # (p,)

    mean_term = squared_euclidean_distance_matrix(projected_data, projected_means)
    return mean_term + jnp.expand_dims(total_variance, 0)
@jit
def wasserstein2_relevance_distance_matrix(
    X: chex.Array, means: chex.Array, log_variances: chex.Array, relevances: chex.Array
) -> chex.Array:
    r"""
    Squared 2-Wasserstein distance with per-feature relevance weights.

    The mean term is a relevance-weighted squared Euclidean distance:

    .. math::

        W_2^2(x, k) = \sum_j \lambda_j (x_j - \mu_{kj})^2 + \sum_j \sigma_{kj}^2

    with :math:`\lambda = \text{softmax}(r)`, so the weights are non-negative
    and sum to 1.

    Parameters
    ----------
    X : array of shape (n, d)
        Data points.
    means : array of shape (p, d)
        Prototype mean vectors.
    log_variances : array of shape (p, d)
        Log of prototype variances.
    relevances : array of shape (d,)
        Unnormalized relevance logits; softmax is applied here.

    Returns
    -------
    D : array of shape (n, p)
        Relevance-weighted squared 2-Wasserstein distances.
    """
    chex.assert_rank(X, 2)
    chex.assert_rank(means, 2)
    chex.assert_rank(log_variances, 2)
    chex.assert_rank(relevances, 1)
    chex.assert_equal(X.shape[1], means.shape[1])
    chex.assert_equal(means.shape, log_variances.shape)
    chex.assert_equal(X.shape[1], relevances.shape[0])

    weights = jax.nn.softmax(relevances)  # (d,)

    # Pairwise residuals via broadcasting: (n, 1, d) - (1, p, d) -> (n, p, d)
    residuals = X[:, None, :] - means[None, :, :]
    mean_term = jnp.sum(residuals ** 2 * weights[None, None, :], axis=2)  # (n, p)

    total_variance = jnp.sum(jnp.exp(log_variances), axis=1)  # (p,)
    return mean_term + total_variance[None, :]
# ============================================================================ # Kernel Functions # ============================================================================
@jit
def gaussian_kernel_matrix(
    X: chex.Array, Y: chex.Array, sigma: float
) -> chex.Array:
    r"""
    Compute Gaussian (RBF) kernel matrix.

    Formula:
        :math:`K_{ij} = \exp(-\|X_i - Y_j\|^2 / (2\sigma^2))`

    The Gaussian kernel is an inner product in an infinite-dimensional Hilbert
    space, enabling non-linear clustering and classification.

    Args:
        X: Array of shape (n, d)
        Y: Array of shape (m, d)
        sigma: Bandwidth parameter (:math:`\sigma > 0`)

    Returns:
        K: Array of shape (n, m) where :math:`K_{ij} \in [0, 1]`.
        :math:`K_{ij} = 1` when :math:`X_i = Y_j`;
        :math:`K_{ij} \to 0` as :math:`\|X_i - Y_j\| \to \infty`.

    Complexity:
        Time: O(nmd)
        Space: O(nm)

    Properties:
        - K is positive semi-definite (valid kernel)
        - K is symmetric if X = Y
        - K[i, i] = 1 when X = Y (self-similarity)

    Example:
        >>> X = jnp.array([[0, 0], [1, 1]])
        >>> Y = jnp.array([[0, 0], [2, 2]])
        >>> K = gaussian_kernel_matrix(X, Y, sigma=1.0)
        >>> float(K[0, 0])  # X[0] == Y[0]
        1.0
        >>> bool(K[0, 1] < K[0, 0])  # Decreases with distance
        True

    Kernel Trick:
        For the feature map :math:`\phi` into the Hilbert space,
        :math:`K(x, y) = \langle\phi(x), \phi(y)\rangle`. The kernel distance is
        :math:`\|\phi(x) - \phi(y)\|^2 = K(x,x) - 2K(x,y) + K(y,y) = 2 - 2K(x,y)`
        because :math:`K(x, x) = 1` for this kernel.

    Use Cases:
        - Kernel Fuzzy C-Means (KFCM)
        - Kernel Possibilistic C-Means (KPCM)
        - Support Vector Machines (SVM)
        - Gaussian Processes

    Hyperparameter Tuning:
        - Small :math:`\sigma`: Tight clusters, high sensitivity to noise
        - Large :math:`\sigma`: Smooth clusters, may underfit
        - Rule of thumb:
          :math:`\sigma \approx \text{median}(\text{pairwise distances}) / \sqrt{2 \cdot n_\text{clusters}}`

    Notes:
        - sigma is a bandwidth, NOT a variance (variance = :math:`\sigma^2`)
        - Squared distances come from squared_euclidean_distance_matrix, which
          clamps round-off negatives at 0, so entries never exceed 1.
        - ``sigma = 0`` is not checked; it produces NaN where distances are 0.
        - JIT-compiled for GPU acceleration
    """
    chex.assert_rank(X, 2)
    chex.assert_rank(Y, 2)
    chex.assert_equal(X.shape[1], Y.shape[1])

    # Squared distances, already clamped to be non-negative
    D_sq = squared_euclidean_distance_matrix(X, Y)

    # Apply the Gaussian kernel
    K = jnp.exp(-D_sq / (2 * sigma ** 2))
    return K
@jit
def polynomial_kernel_matrix(
    X: chex.Array, Y: chex.Array, degree: int = 3, coef0: float = 1.0
) -> chex.Array:
    r"""
    Compute polynomial kernel matrix.

    Formula:
        :math:`K_{ij} = (X_i^T Y_j + c)^d` where *d* is ``degree`` and *c* is ``coef0``.

    Args:
        X: Array of shape (n, d)
        Y: Array of shape (m, d)
        degree: Polynomial degree (:math:`d \ge 1`)
        coef0: Coefficient (:math:`c \ge 0`)

    Returns:
        K: Array of shape (n, m) with polynomial kernel values

    Example:
        >>> X = jnp.array([[1, 2], [3, 4]])
        >>> Y = jnp.array([[1, 0], [0, 1]])
        >>> K = polynomial_kernel_matrix(X, Y, degree=2, coef0=1.0)

    Notes:
        - degree=1, coef0=0: Linear kernel (dot product)
        - Higher degree: More complex decision boundaries
        - coef0: Influences importance of lower vs higher order terms
        - Under ``jit``, ``degree`` and ``coef0`` are traced values, so changing
          them does not trigger recompilation. A non-integer ``degree`` applied
          to a negative base ``X_i^T Y_j + c`` yields NaN.
    """
    chex.assert_rank(X, 2)
    chex.assert_rank(Y, 2)
    chex.assert_equal(X.shape[1], Y.shape[1])

    # Compute dot products
    dot_products = X @ Y.T  # (n, m)

    # Apply polynomial kernel
    K = jnp.power(dot_products + coef0, degree)
    return K
# ============================================================================ # Pairwise Distance Functions (for single pairs) # ============================================================================
@jit
def euclidean_distance(x: chex.Array, y: chex.Array) -> chex.Array:
    """
    Euclidean distance between two vectors.

    Args:
        x: Array of shape (d,)
        y: Array of shape (d,)

    Returns:
        Scalar distance

    Example:
        >>> x = jnp.array([0, 0, 0])
        >>> y = jnp.array([1, 1, 1])
        >>> d = euclidean_distance(x, y)
        >>> d
        Array(1.732..., dtype=float32)

    Notes:
        Gradient-safe at ``x == y``: the square root is taken of a dummy value
        there and masked to 0, so ``jax.grad`` returns 0 instead of NaN.
        Forward values are unchanged.
    """
    chex.assert_equal_shape([x, y])
    sq = jnp.sum((x - y) ** 2)
    # d/ds sqrt(s) is infinite at s = 0 ("double where" trick avoids NaN grads).
    # ``== 0`` rather than ``> 0`` so NaN inputs still propagate as NaN.
    is_zero = sq == 0
    return jnp.where(is_zero, 0.0, jnp.sqrt(jnp.where(is_zero, 1.0, sq)))
@jit
def squared_euclidean_distance(x: chex.Array, y: chex.Array) -> chex.Array:
    """
    Sum of squared coordinate differences between two vectors.

    Args:
        x: Array of shape (d,)
        y: Array of shape (d,), same shape as ``x``

    Returns:
        Scalar squared Euclidean distance
    """
    chex.assert_equal_shape([x, y])
    delta = x - y
    return jnp.square(delta).sum()
@jit
def manhattan_distance(x: chex.Array, y: chex.Array) -> chex.Array:
    """
    L1 (taxicab) distance between two vectors.

    Args:
        x: Array of shape (d,)
        y: Array of shape (d,), same shape as ``x``

    Returns:
        Scalar sum of absolute coordinate differences
    """
    chex.assert_equal_shape([x, y])
    abs_delta = jnp.abs(x - y)
    return abs_delta.sum()
def lpnorm_distance(x: chex.Array, y: chex.Array, p: float = 2) -> chex.Array:
    """
    L-p norm of the difference between two vectors.

    Delegates to ``jnp.linalg.norm``, so ``p`` accepts any vector ``ord`` it
    supports, including ``jnp.inf``.

    Args:
        x: Array of shape (d,)
        y: Array of shape (d,), same shape as ``x``
        p: Norm order (default 2)

    Returns:
        Scalar distance
    """
    chex.assert_equal_shape([x, y])
    residual = x - y
    distance = jnp.linalg.norm(residual, ord=p)
    return distance
@jit
def omega_distance(x: chex.Array, y: chex.Array, omega: chex.Array) -> chex.Array:
    r"""
    Omega (projection-based) squared distance between two vectors.

    Computes :math:`\|(x - y)\,\Omega\|^2`.

    Args:
        x: Array of shape (d,)
        y: Array of shape (d,)
        omega: Projection matrix of shape (d, k)

    Returns:
        Scalar squared distance in the projected space
    """
    chex.assert_equal_shape([x, y])
    diff = x - y
    projected = diff @ omega  # (k,)
    return jnp.sum(projected ** 2)
def lomega_distance(X: chex.Array, Y: chex.Array, omegas: chex.Array) -> chex.Array:
    """
    Local omega distance, built from a per-pair kernel with nested ``vmap``.

    Entry (i, j) is ``sum(((X[i] - Y[j]) @ omegas[j]) ** 2)``: each prototype
    is compared through its own projection matrix.

    Args:
        X: Array of shape (n, d)
        Y: Array of shape (m, d) — prototypes
        omegas: Array of shape (m, d, k) — one projection matrix per prototype

    Returns:
        Distance matrix of shape (n, m)
    """
    def _pair_distance(sample, prototype, omega):
        return jnp.sum(((sample - prototype) @ omega) ** 2)

    # Inner map: one sample against every (prototype, omega) pair.
    over_prototypes = jax.vmap(_pair_distance, in_axes=(None, 0, 0))
    # Outer map: repeat for every sample, sharing prototypes and omegas.
    over_samples = jax.vmap(over_prototypes, in_axes=(0, None, None))
    return over_samples(X, Y, omegas)
# ============================================================================ # Utility Functions # ============================================================================
def estimate_sigma(
    X: chex.Array,
    percentile: float = 50.0,
    max_samples: int = 1000,
    key: jax.Array | None = None,
) -> float:
    r"""
    Estimate a Gaussian-kernel bandwidth from pairwise Euclidean distances.

    Returns the requested percentile (the median by default) of the distances
    between all distinct pairs of rows of ``X``. No further scaling is applied.
    Callers who want the rule of thumb
    :math:`\sigma \approx \text{median distance} / \sqrt{2 \cdot n_\text{clusters}}`
    must divide the returned value themselves.

    Args:
        X: Data array of shape (n, d) with n >= 2.
        percentile: Percentile of the pairwise distances to return (0-100).
        max_samples: If n exceeds this, a random subset of ``max_samples`` rows
            (without replacement) is used, so the O(n^2) distance matrix stays
            bounded. Must be >= 2.
        key: PRNG key for the subsample. Defaults to ``jax.random.PRNGKey(0)``,
            so repeated calls on the same data give the same estimate.

    Returns:
        sigma: Estimated bandwidth as a Python float.

    Raises:
        ValueError: If ``X`` has fewer than 2 rows or ``max_samples < 2``,
            because then no pairwise distance exists and the percentile would
            be NaN.

    Example:
        >>> X = jax.random.normal(jax.random.PRNGKey(0), (100, 10))
        >>> sigma = estimate_sigma(X, percentile=50)

    Notes:
        - If all sampled rows coincide, the result is 0.0, which is not a
          usable kernel bandwidth.
    """
    n = X.shape[0]
    if n < 2:
        raise ValueError(
            f"estimate_sigma needs at least 2 samples to form a pairwise distance; got {n}."
        )
    if max_samples < 2:
        raise ValueError(f"max_samples must be >= 2; got {max_samples}.")

    # Bound the O(n^2) distance computation on large datasets.
    if n > max_samples:
        if key is None:
            key = jax.random.PRNGKey(0)
        indices = jax.random.choice(key, n, shape=(max_samples,), replace=False)
        X_sub = X[indices]
    else:
        X_sub = X

    D = euclidean_distance_matrix(X_sub, X_sub)

    # Strict upper triangle: every distinct pair once, diagonal zeros excluded.
    rows, cols = jnp.triu_indices(X_sub.shape[0], k=1)
    distances = D[rows, cols]

    return float(jnp.percentile(distances, percentile))
@jit
def safe_divide(numerator: chex.Array, denominator: chex.Array, epsilon: float = 1e-10) -> chex.Array:
    """
    Divide after adding a small constant to the denominator.

    Args:
        numerator: Numerator array
        denominator: Denominator array
        epsilon: Small value added to the denominator

    Returns:
        ``numerator / (denominator + epsilon)``

    Example:
        >>> x = jnp.array([1.0, 2.0, 3.0])
        >>> y = jnp.array([2.0, 0.0, 1.0])
        >>> safe_divide(x, y)  # 2 / 1e-10 = 2e10: large but finite, not inf
        Array([5.e-01, 2.e+10, 3.e+00], dtype=float32)

    Notes:
        - The offset is added unconditionally, so it only guards non-negative
          denominators; a denominator equal to ``-epsilon`` still divides by
          zero. NOTE(review): callers with signed denominators may need a
          sign-aware offset.
        - In float32, ``1e-10`` is below the resolution of denominators near 1,
          so ordinary divisions are numerically unaffected.
    """
    return numerator / (denominator + epsilon)
# ============================================================================ # Module Information # ============================================================================ # Aliases for convenience batch_squared_euclidean = squared_euclidean_distance_matrix batch_euclidean = euclidean_distance_matrix __all__ = [ # Matrix distance functions 'euclidean_distance_matrix', 'squared_euclidean_distance_matrix', 'manhattan_distance_matrix', 'lpnorm_distance_matrix', 'omega_distance_matrix', 'lomega_distance_matrix', 'tangent_distance_matrix', # Wasserstein distance functions 'wasserstein2_distance_matrix', 'wasserstein2_omega_distance_matrix', 'wasserstein2_relevance_distance_matrix', # Kernel functions 'gaussian_kernel_matrix', 'polynomial_kernel_matrix', # Pairwise functions 'euclidean_distance', 'squared_euclidean_distance', 'manhattan_distance', 'lpnorm_distance', 'omega_distance', 'lomega_distance', # Utilities 'estimate_sigma', 'safe_divide', # Aliases 'batch_squared_euclidean', 'batch_euclidean', ]