"""
JAX-based kernel functions for kernel clustering methods.
This module provides GPU-accelerated kernel computations using JAX.
"""
import jax.numpy as jnp
import chex
from jax import jit
from functools import partial
[docs]
@jit
def gaussian_kernel(x: chex.Array, y: chex.Array, sigma: float) -> chex.Array:
    r"""
    Compute the Gaussian (RBF) kernel between two vectors.

    :math:`K(x, y) = \exp(-\|x - y\|^2 / (2\sigma^2))`

    The docstring is a raw string because the LaTeX above contains
    backslash sequences (``\exp``, ``\|``, ``\sigma``) that are not valid
    escapes in an ordinary string literal.

    Args:
        x: First vector, shape (n_features,)
        y: Second vector, shape (n_features,)
        sigma: Kernel bandwidth parameter. Must be non-zero, since the
            squared distance is divided by ``2 * sigma ** 2``.

    Returns:
        Kernel value (scalar)
    """
    diff = x - y
    # Squared Euclidean norm of the difference, summed over every element.
    sq_norm = jnp.sum(diff * diff)
    return jnp.exp(-sq_norm / (2.0 * sigma ** 2))
[docs]
@jit
def batch_gaussian_kernel(
    X: chex.Array, Y: chex.Array, sigma: float
) -> chex.Array:
    """
    Compute the Gaussian (RBF) kernel matrix between two sets of vectors.

    Args:
        X: First set of vectors, shape (n_samples, n_features). A single
            vector of shape (n_features,) is treated as one row.
        Y: Second set of vectors, shape (m_samples, n_features). A single
            vector of shape (n_features,) is treated as one row.
        sigma: Kernel bandwidth parameter. Must be non-zero, since the
            squared distances are divided by ``2 * sigma ** 2``.

    Returns:
        Kernel matrix, shape (n_samples, m_samples), with
        K[i, j] = exp(-||X[i] - Y[j]||^2 / (2 * sigma^2))
    """
    # Promote a lone vector to a one-row matrix so the axis=1 reductions
    # and the matrix product below are well defined. 2-D inputs pass
    # through unchanged.
    X = jnp.atleast_2d(X)
    Y = jnp.atleast_2d(Y)

    # Squared row norms, kept as columns so they broadcast against X @ Y.T.
    X_sq = jnp.sum(X ** 2, axis=1, keepdims=True)  # (n_samples, 1)
    Y_sq = jnp.sum(Y ** 2, axis=1, keepdims=True)  # (m_samples, 1)

    # ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x, y>
    sq_distances = X_sq + Y_sq.T - 2.0 * (X @ Y.T)  # (n_samples, m_samples)

    # Round-off in the expansion above can leave tiny negative values;
    # clamp so the kernel never exceeds 1.
    sq_distances = jnp.maximum(sq_distances, 0.0)

    return jnp.exp(-sq_distances / (2.0 * sigma ** 2))
[docs]
@jit
def kernel_distance_squared(
    X: chex.Array, Y: chex.Array, sigma: float
) -> chex.Array:
    """
    Squared distance between points after mapping into the kernel's
    feature space.

    In general ||phi(x) - phi(y)||^2 = K(x,x) + K(y,y) - 2K(x,y). The
    Gaussian kernel has K(x,x) = 1 for every x, so this reduces to
    2(1 - K(x,y)), which is what is evaluated here.

    Args:
        X: First set of vectors, shape (n_samples, n_features)
        Y: Second set of vectors, shape (m_samples, n_features)
        sigma: Kernel bandwidth parameter

    Returns:
        Squared distances in feature space, shape (n_samples, m_samples)
    """
    similarity = batch_gaussian_kernel(X, Y, sigma)
    one_minus_similarity = 1.0 - similarity
    return 2.0 * one_minus_similarity
[docs]
@jit
def kernel_distance(
    X: chex.Array, Y: chex.Array, sigma: float
) -> chex.Array:
    """
    Compute distance in feature space.

    This is the element-wise square root of ``kernel_distance_squared``.
    Entries that are zero (or negative through round-off) map to exactly
    0, and the square root is evaluated in a way that keeps gradients
    finite at those entries.

    Args:
        X: First set of vectors, shape (n_samples, n_features)
        Y: Second set of vectors, shape (m_samples, n_features)
        sigma: Kernel bandwidth parameter

    Returns:
        Distances in feature space, shape (n_samples, m_samples)
    """
    D_sq = kernel_distance_squared(X, Y, sigma)

    # The derivative of sqrt is infinite at 0, so differentiating
    # sqrt(maximum(D_sq, 0)) produces NaN wherever X[i] == Y[j] (e.g. the
    # diagonal when X is Y). Feed sqrt a harmless placeholder at those
    # entries and select 0 for them afterwards: forward values are
    # unchanged and the gradient there becomes 0.
    # `<=` (rather than `not >`) keeps NaN entries on the sqrt path so they
    # still propagate as NaN.
    at_zero = D_sq <= 0.0
    safe_sq = jnp.where(at_zero, 1.0, D_sq)
    return jnp.where(at_zero, 0.0, jnp.sqrt(safe_sq))