Source code for prosemble.core.kernel

"""
JAX-based kernel functions for kernel clustering methods.

This module provides GPU-accelerated kernel computations using JAX.
"""

import jax.numpy as jnp
import chex
from jax import jit
from functools import partial


@jit
def gaussian_kernel(x: chex.Array, y: chex.Array, sigma: float) -> chex.Array:
    r"""
    Evaluate the Gaussian (RBF) kernel for a single pair of vectors.

    :math:`K(x, y) = \exp(-\|x - y\|^2 / (2\sigma^2))`

    Args:
        x: First vector, shape (n_features,)
        y: Second vector, shape (n_features,)
        sigma: Kernel bandwidth parameter

    Returns:
        Kernel value (scalar)
    """
    # Squared Euclidean distance between the two points.
    delta = x - y
    squared_distance = jnp.sum(jnp.square(delta))

    # Bandwidth scaling 2 * sigma^2 from the kernel definition.
    bandwidth_scale = 2.0 * sigma ** 2
    return jnp.exp(-squared_distance / bandwidth_scale)
@jit
def batch_gaussian_kernel(
    X: chex.Array,
    Y: chex.Array,
    sigma: float
) -> chex.Array:
    r"""
    Build the Gaussian kernel matrix between two collections of vectors.

    Args:
        X: First set of vectors, shape (n_samples, n_features)
        Y: Second set of vectors, shape (m_samples, n_features)
        sigma: Kernel bandwidth parameter

    Returns:
        Kernel matrix, shape (n_samples, m_samples)
        K[i, j] = K(X[i], Y[j])
    """
    # Per-row squared norms, laid out so they broadcast to an (n, m) grid.
    x_norms = jnp.sum(jnp.square(X), axis=1)[:, None]  # (n_samples, 1)
    y_norms = jnp.sum(jnp.square(Y), axis=1)[None, :]  # (1, m_samples)

    # Inner products <x_i, y_j> for every pair.
    cross_terms = X @ Y.T  # (n_samples, m_samples)

    # ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x, y>
    pairwise_sq = x_norms + y_norms - 2.0 * cross_terms

    # The expanded form can dip slightly below zero through rounding;
    # clamp so the exponent is never positive.
    pairwise_sq = jnp.maximum(pairwise_sq, 0.0)

    return jnp.exp(-pairwise_sq / (2.0 * sigma ** 2))
@jit
def kernel_distance_squared(
    X: chex.Array,
    Y: chex.Array,
    sigma: float
) -> chex.Array:
    r"""
    Squared distance between points after mapping into the Gaussian
    kernel's feature space.

    For Gaussian kernel:
    :math:`\|\phi(x) - \phi(y)\|^2 = K(x,x) + K(y,y) - 2K(x,y) = 2(1 - K(x,y))`,
    since :math:`K(x,x) = 1`.

    Args:
        X: First set of vectors, shape (n_samples, n_features)
        Y: Second set of vectors, shape (m_samples, n_features)
        sigma: Kernel bandwidth parameter

    Returns:
        Squared distances in feature space, shape (n_samples, m_samples)
    """
    # Pairwise kernel similarities; the self-similarity terms are both 1,
    # which is what collapses the three-term expansion to 2 * (1 - K).
    similarity = batch_gaussian_kernel(X, Y, sigma)
    return 2.0 * (1.0 - similarity)
@jit
def kernel_distance(
    X: chex.Array,
    Y: chex.Array,
    sigma: float
) -> chex.Array:
    """
    Distance between points in the Gaussian kernel's feature space.

    This is the element-wise square root of ``kernel_distance_squared``.

    Args:
        X: First set of vectors, shape (n_samples, n_features)
        Y: Second set of vectors, shape (m_samples, n_features)
        sigma: Kernel bandwidth parameter

    Returns:
        Distances in feature space, shape (n_samples, m_samples)
    """
    squared = kernel_distance_squared(X, Y, sigma)

    # Clamp before the root so no negative value reaches sqrt.
    # NOTE(review): the derivative of sqrt is unbounded at 0, so
    # differentiating through this at coincident points may yield
    # non-finite gradients -- confirm whether callers rely on autodiff here.
    non_negative = jnp.maximum(squared, 0.0)
    return jnp.sqrt(non_negative)
@jit
def kernel_distance_squared_per_proto(
    X: chex.Array,
    W: chex.Array,
    sigmas: chex.Array
) -> chex.Array:
    r"""
    Squared Gaussian-kernel distance where every prototype carries its
    own bandwidth.

    .. math::

        d_\kappa^2(x, w_k) = 2\left(1 - \exp\left(-\frac{\|x - w_k\|^2}{2\sigma_k^2}\right)\right)

    Each prototype :math:`w_k` has its own bandwidth :math:`\sigma_k`.

    Args:
        X: Data matrix, shape (n_samples, n_features).
        W: Prototype matrix, shape (n_prototypes, n_features).
        sigmas: Per-prototype bandwidths, shape (n_prototypes,).

    Returns:
        Squared distances in feature space, shape (n_samples, n_prototypes).

    References
    ----------
    .. [1] Villmann, T., Haase, S., & Kaden, M. (2015). Kernelized vector
       quantization in gradient-descent learning. Neurocomputing.
    """
    # Every sample against every prototype: (n, p, d).
    pairwise_diff = X[:, None, :] - W[None, :, :]

    # Squared Euclidean distance per (sample, prototype) pair: (n, p).
    sq_dist = jnp.sum(jnp.square(pairwise_diff), axis=-1)

    # One bandwidth per prototype column, broadcast over samples: (1, p).
    bandwidth_scale = 2.0 * sigmas[None, :] ** 2

    similarity = jnp.exp(-sq_dist / bandwidth_scale)  # (n, p)
    return 2.0 * (1.0 - similarity)
@jit
def kernel_distance_squared_relevance(
    X: chex.Array,
    W: chex.Array,
    sigmas: chex.Array,
    relevances: chex.Array
) -> chex.Array:
    r"""
    Squared Gaussian-kernel distance using a relevance-weighted input
    distance and a separate bandwidth for each prototype.

    .. math::

        d_\kappa^2(x, w_k) = 2\left(1 - \exp\left(-\frac{\sum_j \lambda_j (x_j - w_{kj})^2}{2\sigma_k^2}\right)\right)

    Args:
        X: Data matrix, shape (n_samples, n_features).
        W: Prototype matrix, shape (n_prototypes, n_features).
        sigmas: Per-prototype bandwidths, shape (n_prototypes,).
        relevances: Normalized relevance weights, shape (n_features,).

    Returns:
        Squared distances in feature space, shape (n_samples, n_prototypes).

    References
    ----------
    .. [1] Villmann, T., Haase, S., & Kaden, M. (2015). Kernelized vector
       quantization in gradient-descent learning. Neurocomputing.
    """
    # Every sample against every prototype: (n, p, d).
    pairwise_diff = X[:, None, :] - W[None, :, :]

    # Scale each feature's squared difference by its relevance, then
    # collapse the feature axis: (n, p).
    per_feature = relevances[None, None, :] * jnp.square(pairwise_diff)
    weighted_sq_dist = jnp.sum(per_feature, axis=-1)

    # One bandwidth per prototype column, broadcast over samples: (1, p).
    bandwidth_scale = 2.0 * sigmas[None, :] ** 2

    similarity = jnp.exp(-weighted_sq_dist / bandwidth_scale)  # (n, p)
    return 2.0 * (1.0 - similarity)
@jit
def exponential_kernel_distance_squared(
    X: chex.Array,
    W: chex.Array,
    omega_hat: chex.Array
) -> chex.Array:
    r"""
    Squared feature-space distance induced by the exponential kernel.

    Uses the exponential kernel
    :math:`\kappa_{\exp}(v, w, \hat\Lambda) = \exp(v^T \hat\Lambda w)`
    where :math:`\hat\Lambda = \hat\Omega \hat\Omega^T`.

    .. math::

        d_\kappa^2(x, w) = \exp(x^T \hat\Lambda x) + \exp(w^T \hat\Lambda w)
                           - 2 \exp(x^T \hat\Lambda w)

    Note: :math:`\kappa(v, v) \neq 1` for the exponential kernel, so the
    full three-term formula is required (not the 2(1-K) simplification).

    Args:
        X: Data matrix, shape (n_samples, n_features).
        W: Prototype matrix, shape (n_prototypes, n_features).
        omega_hat: Transformation matrix, shape (n_features, latent_dim).
            The kernel matrix is :math:`\hat\Lambda = \hat\Omega \hat\Omega^T`.

    Returns:
        Squared distances in feature space, shape (n_samples, n_prototypes).

    References
    ----------
    .. [1] Villmann, T., Haase, S., & Kaden, M. (2015). Kernelized vector
       quantization in gradient-descent learning. Neurocomputing.
    """
    # Kernel matrix built from the transformation: (d, d).
    kernel_matrix = jnp.dot(omega_hat, omega_hat.T)

    # Rows of X and W multiplied by the kernel matrix; the sample-side
    # product is reused for both the self term and the cross term.
    X_mapped = jnp.dot(X, kernel_matrix)  # (n, d)
    W_mapped = jnp.dot(W, kernel_matrix)  # (p, d)

    # Quadratic forms x^T L x and w^T L w, one value per row.
    sample_quad = jnp.sum(X * X_mapped, axis=1)  # (n,)
    proto_quad = jnp.sum(W * W_mapped, axis=1)   # (p,)

    # Bilinear form x^T L w for every (sample, prototype) pair.
    cross_form = jnp.dot(X_mapped, W.T)  # (n, p)

    # Three kernel evaluations: kappa(x, x), kappa(w, w) and kappa(x, w).
    k_xx = jnp.exp(sample_quad[:, None])
    k_ww = jnp.exp(proto_quad[None, :])
    k_xw = jnp.exp(cross_form)

    squared = k_xx + k_ww - 2.0 * k_xw

    # Clamp so the returned squared distances are never negative.
    return jnp.maximum(squared, 0.0)