Source code for prosemble.models.dk_heskes_som

"""
Differentiating Kernel Heskes SOM (DKHeskesSOM).

Heskes SOM with Gaussian kernel distance:

.. math::

    d_\\kappa^2(x, w_k) = 2\\left(1 - \\exp\\left(
        -\\frac{\\|x - w_k\\|^2}{2\\sigma^2}
    \\right)\\right)

The kernel bandwidth :math:`\\sigma` is a fixed hyperparameter.
The Heskes BMU criterion and batch update remain unchanged.
Prototypes live in the original data space.

References
----------
.. [1] Villmann, T., Haase, S., & Kaden, M. (2015). Kernelized vector
       quantization in gradient-descent learning. Neurocomputing.
"""

from functools import partial

import jax.numpy as jnp
from jax import jit

from prosemble.models.heskes_som import HeskesSOM, HeskesSOMState
from prosemble.core.kernel import kernel_distance_squared_per_proto


[docs] class DKHeskesSOM(HeskesSOM): """Differentiating Kernel Heskes SOM. Heskes SOM with Gaussian kernel distance. The kernel bandwidth :math:`\\sigma` is a fixed hyperparameter (not learned). The Heskes BMU criterion and batch update operate in the original data space — only the data-space distance metric changes. The Heskes BMU criterion uses kernel distance: .. math:: c^*(x) = \\arg\\min_c \\sum_k h(k, c) \\cdot d_\\kappa^2(x, w_k) Parameters ---------- kernel_sigma : float Gaussian kernel bandwidth for data-space distance. Default: 1.0. grid_height : int Height of the 2D grid. grid_width : int Width of the 2D grid. sigma_init : float, optional Initial grid neighborhood radius. sigma_final : float Final grid neighborhood radius. max_iter : int Maximum training iterations. epsilon : float Convergence threshold. random_seed : int Random seed. callbacks : list, optional Callback objects. use_scan : bool If True (default), use jax.lax.scan for training. patience : int, optional Epochs with no improvement before early stopping. restore_best : bool If True, restore best parameters after training. References ---------- .. [1] Villmann, T., Haase, S., & Kaden, M. (2015). Kernelized vector quantization in gradient-descent learning. Neurocomputing. See Also -------- HeskesSOM : Base class with Euclidean distance. """ def __init__(self, grid_height=10, grid_width=10, kernel_sigma=1.0, sigma_init=None, sigma_final=0.5, max_iter=100, lr=0.01, epsilon=1e-6, random_seed=42, distance_fn=None, callbacks=None, use_scan=True, patience=None, restore_best=False): super().__init__( grid_height=grid_height, grid_width=grid_width, sigma_init=sigma_init, sigma_final=sigma_final, max_iter=max_iter, lr=lr, epsilon=epsilon, random_seed=random_seed, distance_fn=distance_fn, callbacks=callbacks, use_scan=use_scan, patience=patience, restore_best=restore_best, ) self.kernel_sigma = kernel_sigma def _kernel_distances(self, X, prototypes): """Compute kernel distances with broadcast sigma.""" sigmas = jnp.full(prototypes.shape[0], self.kernel_sigma) return kernel_distance_squared_per_proto(X, prototypes, sigmas) @partial(jit, static_argnums=(0,)) def _heskes_step(self, state, X, grid_dist_sq, sigma_init): """Single JIT-compiled DK Heskes SOM training step.""" t = state.iteration max_t = jnp.array(max(self.max_iter - 1, 1), dtype=jnp.float32) frac = t.astype(jnp.float32) / max_t sigma_t = sigma_init * (self.sigma_final / sigma_init) ** frac prototypes = state.prototypes # Kernel distances: (n_samples, n_protos) distances = self._kernel_distances(X, prototypes) # Neighborhood matrix: h(k, c) for all pairs — (n_protos, n_protos) h_matrix = jnp.exp(-grid_dist_sq / (2.0 * sigma_t ** 2)) # Heskes BMU: c*(x) = argmin_c Σ_k h(k,c) * d_κ(x,k) weighted_cost = jnp.dot(distances, h_matrix) # (n_samples, n_protos) bmu_indices = jnp.argmin(weighted_cost, axis=1) # Neighborhood weights for each sample: h(k, c*(x)) h_weights = h_matrix[:, bmu_indices].T # (n_samples, n_protos) # Batch update: w_k = Σ_x h(k,c*(x)) * x / Σ_x h(k,c*(x)) numerator = jnp.dot(h_weights.T, X) # (n_protos, n_features) denominator = jnp.sum(h_weights, axis=0)[:, None] new_prototypes = numerator / (denominator + 1e-10) # Heskes energy: E = Σ_x Σ_k h(k,c*(x)) * d_κ(x,k) energy = jnp.sum(h_weights * distances) # Convergence has_converged = state.converged | ( jnp.abs(energy - state.prev_loss) < self.epsilon ) frozen_prototypes = jnp.where( state.converged, prototypes, new_prototypes ) frozen_energy = jnp.where(state.converged, state.loss, energy) new_state = HeskesSOMState( prototypes=frozen_prototypes, loss=frozen_energy, prev_loss=energy, converged=has_converged, iteration=t + 1, ) return new_state, frozen_energy def _fit_with_python_loop(self, X, prototypes, grid_dist_sq, sigma_init_val): """Python for-loop training with kernel distance.""" loss_history = [] best_loss = None best_prototypes = None for t in range(self.max_iter): frac = t / max(self.max_iter - 1, 1) sigma_t = sigma_init_val * ( self.sigma_final / sigma_init_val ) ** frac distances = self._kernel_distances(X, prototypes) h_matrix = jnp.exp(-grid_dist_sq / (2.0 * sigma_t ** 2)) # Heskes BMU weighted_cost = jnp.dot(distances, h_matrix) bmu_indices = jnp.argmin(weighted_cost, axis=1) # Neighborhood weights h_weights = h_matrix[:, bmu_indices].T # Batch update numerator = jnp.dot(h_weights.T, X) denominator = jnp.sum(h_weights, axis=0)[:, None] prototypes = numerator / (denominator + 1e-10) # Energy energy = float(jnp.sum(h_weights * distances)) loss_history.append(energy) if self.restore_best and (best_loss is None or energy < best_loss): best_loss = energy best_prototypes = prototypes if t > 0 and abs(loss_history[-1] - loss_history[-2]) < self.epsilon: break if self.patience is not None and self._check_patience(loss_history, self.patience): break if self.restore_best and best_prototypes is not None: prototypes = best_prototypes self.best_loss_ = best_loss self.prototypes_ = prototypes self.n_iter_ = t + 1 self.loss_ = loss_history[-1] self.loss_history_ = jnp.array(loss_history) return self
[docs] def bmu_map(self, X): """Return BMU grid coordinates using Heskes criterion with kernel distance. Parameters ---------- X : array of shape (n, d) Returns ------- coords : array of shape (n, 2) — (row, col) for each sample """ self._check_fitted() X = jnp.asarray(X, dtype=jnp.float32) distances = self._kernel_distances(X, self.prototypes_) grid_pos = self._grid_positions grid_dist_sq = jnp.sum( (grid_pos[:, None, :] - grid_pos[None, :, :]) ** 2, axis=2 ) h_matrix = jnp.exp(-grid_dist_sq / (2.0 * self.sigma_final ** 2)) weighted_cost = jnp.dot(distances, h_matrix) bmu_indices = jnp.argmin(weighted_cost, axis=1) return self._grid_positions[bmu_indices]
def _get_hyperparams(self): hp = super()._get_hyperparams() hp['kernel_sigma'] = self.kernel_sigma return hp