Source code for prosemble.core.similarities

"""
Similarity functions for prototype-based learning.

Similarities are the dual of distances: higher values indicate
closer/more similar points.
"""

import jax.numpy as jnp
from jax import jit


@jit
def gaussian_similarity(distances_sq, variance=1.0):
    r"""Convert squared distances to Gaussian similarities.

    :math:`s(d) = \exp(-d^2 / (2 \cdot \text{variance}))`

    Parameters
    ----------
    distances_sq : array
        Squared distances :math:`d^2`. The input is used as-is (it is
        not squared again).
    variance : float
        Variance (sigma^2) of the Gaussian.

    Returns
    -------
    array
        Elementwise similarity values; in (0, 1] for non-negative input
        and positive variance (very large distances may underflow to 0).
    """
    # The docstring is raw (r"""): otherwise ``\t`` in ``\text`` becomes a
    # TAB and ``\e`` is an invalid escape, corrupting the rendered formula.
    return jnp.exp(-distances_sq / (2.0 * variance))
@jit
def cosine_similarity_matrix(X, Y):
    r"""Pairwise cosine similarity between rows of X and Y.

    :math:`\cos(x, y) = \frac{x \cdot y}{\|x\| \cdot \|y\|}`

    Parameters
    ----------
    X : array of shape (n, d)
    Y : array of shape (m, d)

    Returns
    -------
    array of shape (n, m)
        Cosine similarities in [-1, 1]. Row norms are clamped from below
        to the machine epsilon of the working dtype, so the result stays
        finite; an all-zero row yields similarity 0.

    Notes
    -----
    Both inputs are cast to their common dtype. If that dtype is not
    floating (e.g. integer or boolean inputs), it is promoted to the
    default floating dtype. The working dtype also determines ``eps``.
    """
    dtype = jnp.result_type(X, Y)
    if not jnp.issubdtype(dtype, jnp.inexact):
        # jnp.finfo is undefined for integer/bool dtypes, so promote first.
        dtype = jnp.result_type(dtype, float)
    X = jnp.asarray(X, dtype=dtype)
    Y = jnp.asarray(Y, dtype=dtype)

    norm_X = jnp.linalg.norm(X, axis=1, keepdims=True)  # (n, 1)
    norm_Y = jnp.linalg.norm(Y, axis=1, keepdims=True)  # (m, 1)

    # Clamp norms so all-zero rows do not cause a division by zero.
    eps = jnp.finfo(dtype).eps
    norm_X = jnp.maximum(norm_X, eps)
    norm_Y = jnp.maximum(norm_Y, eps)

    dot_product = X @ Y.T  # (n, m)
    norm_product = norm_X @ norm_Y.T  # (n, m) outer product of row norms
    return dot_product / norm_product
@jit
def euclidean_similarity(X, Y, variance=1.0):
    """Gaussian kernel of the pairwise Euclidean distance between rows.

    For every row ``x`` of X and row ``y`` of Y the result is
    ``exp(-||x - y||^2 / (2 * variance))``.

    Parameters
    ----------
    X : array of shape (n, d)
        First set of points, one per row.
    Y : array of shape (m, d)
        Second set of points, one per row.
    variance : float
        Variance of the Gaussian kernel.

    Returns
    -------
    array of shape (n, m)
        Similarity values in (0, 1].
    """
    # Insert singleton axes so every X row pairs with every Y row: (n, m, d).
    pairwise_delta = jnp.expand_dims(X, 1) - jnp.expand_dims(Y, 0)
    squared_dist = jnp.square(pairwise_delta).sum(axis=-1)  # (n, m)
    kernel_width = 2.0 * variance
    return jnp.exp(-squared_dist / kernel_width)
@jit
def rank_scaled_gaussian(distances, lambd=1.0):
    r"""Rank-scaled Gaussian similarity.

    Combines distance magnitude with rank ordering. Each distance is
    multiplied by a rank-dependent weight :math:`\exp(-r / \lambda)`.
    The closest prototype (rank 0) keeps its full distance. Distances at
    higher ranks are exponentially down-weighted, so their similarities
    move towards 1.

    .. math::

        s(d, r) = \exp(-\exp(-r / \lambda) \cdot d)

    where *r* is the rank of each distance within its row (0 = closest).
    Tied distances are given distinct consecutive ranks in column order,
    because ``jnp.argsort`` sorts stably by default.

    Parameters
    ----------
    distances : array of shape (n, m)
        Distance matrix (non-negative). Ranks are computed along axis 1.
    lambd : float
        Rank decay parameter. Larger values make the weights more uniform
        across ranks. Smaller values concentrate the distance information
        on the nearest prototype(s).

    Returns
    -------
    array of shape (n, m)
        Rank-scaled similarity values in (0, 1].

    Notes
    -----
    Used in Probabilistic LVQ (PLVQ) as a conditional distribution
    P(x|prototype). The values are not normalized by this function.
    """
    # argsort of the argsort yields each entry's position in its sorted
    # row, i.e. its rank.
    order = jnp.argsort(distances, axis=1)
    ranks = jnp.argsort(order, axis=1).astype(distances.dtype)
    return jnp.exp(-jnp.exp(-ranks / lambd) * distances)