# Source code for prosemble.models.heskes_som

"""
Heskes Self-Organizing Map.

Unlike the standard Kohonen SOM, the Heskes SOM uses a modified
Best Matching Unit (BMU) definition that accounts for the neighborhood
function, and a pure batch update rule (no learning rate). This
guarantees monotonic decrease of a well-defined energy function.

Energy function:

.. math::

    E = \\sum_x \\sum_k h(k, c^*(x)) \\cdot \\|x - w_k\\|^2

Modified BMU:

.. math::

    c^*(x) = \\arg\\min_c \\sum_k h(k, c) \\cdot \\|x - w_k\\|^2

Batch update:

.. math::

    w_k = \\frac{\\sum_x h(k, c^*(x)) \\cdot x}{\\sum_x h(k, c^*(x))}

References
----------
.. [1] Heskes, T. (1999). Energy functions for self-organizing maps.
       In Kohonen Maps, pp. 303-316, Elsevier.
.. [2] Heskes, T. (2001). Self-organizing maps, vector quantization,
       and mixture modeling. IEEE Trans. Neural Networks, 12(6).
"""

from typing import NamedTuple
from functools import partial

import jax
import jax.numpy as jnp
from jax import jit, lax

from prosemble.models.prototype_base import UnsupervisedPrototypeModel
from prosemble.core.distance import squared_euclidean_distance_matrix


class HeskesSOMState(NamedTuple):
    """State for Heskes SOM lax.scan loop.

    This tuple is the carry threaded through ``lax.scan`` by
    ``HeskesSOM._fit_scan``; ``HeskesSOM._heskes_step`` consumes one state
    and returns the next one.
    """
    # Prototype vectors used by the next step; once `converged` is set the
    # step keeps returning the same prototypes instead of the updated ones.
    prototypes: jnp.ndarray
    # Energy reported for the step; held at its last value after convergence.
    # Initialised to +inf before the first step.
    loss: jnp.ndarray
    # Energy computed in the previous step (not frozen); compared with the
    # current energy against `epsilon`. Initialised to +inf.
    prev_loss: jnp.ndarray
    # Sticky boolean flag: once True it stays True for all remaining steps.
    converged: jnp.ndarray
    # Step counter starting at 0; drives the sigma annealing schedule.
    iteration: jnp.ndarray


class HeskesSOM(UnsupervisedPrototypeModel):
    """Heskes Self-Organizing Map.

    Uses a modified BMU definition that considers the neighborhood
    structure, and a pure batch update (weighted average of data).
    Guarantees monotonic decrease of the Heskes energy function.

    Differences from KohonenSOM:

    - BMU is chosen to minimize neighborhood-weighted distance sum,
      not raw distance to closest prototype.
    - Prototypes are updated via weighted average (no learning rate).
    - Energy is guaranteed to decrease monotonically.

    Parameters
    ----------
    grid_height : int
        Height of the 2D grid.
    grid_width : int
        Width of the 2D grid.
    sigma_init : float, optional
        Initial neighborhood radius. A falsy value (``None`` or ``0``)
        selects the default ``max(grid_height, grid_width) / 2``.
    sigma_final : float
        Final neighborhood radius. Also used as the neighborhood radius
        by :meth:`bmu_map`.
    max_iter : int
        Maximum training iterations.
    lr : float
        Forwarded to the base class. The batch update implemented in this
        class does not read it.
    epsilon : float
        Convergence threshold on the absolute change of the energy
        between two consecutive iterations.
    random_seed : int
        Random seed.
    distance_fn : callable, optional
        Forwarded to the base class. The training and BMU code in this
        class always uses squared Euclidean distances.
    callbacks : list, optional
        Callback objects.
    use_scan : bool
        If True (default), use jax.lax.scan for training (faster,
        JIT-compiled, but runs all max_iter iterations even after
        convergence). If False, use a Python for-loop with true early
        stopping. The scan path is only taken when ``patience`` is None
        and ``restore_best`` is False.
    patience : int, optional
        Epochs with no improvement before early stopping. Default: None.
    restore_best : bool
        If True, restore parameters from the lowest-loss epoch.
        Default: False.
    """

    def __init__(self, grid_height=10, grid_width=10, sigma_init=None,
                 sigma_final=0.5, max_iter=100, lr=0.01, epsilon=1e-6,
                 random_seed=42, distance_fn=None, callbacks=None,
                 use_scan=True, patience=None, restore_best=False):
        n_prototypes = grid_height * grid_width
        super().__init__(
            n_prototypes=n_prototypes,
            max_iter=max_iter,
            lr=lr,
            epsilon=epsilon,
            random_seed=random_seed,
            distance_fn=distance_fn,
            callbacks=callbacks,
            use_scan=use_scan,
            patience=patience,
            restore_best=restore_best,
        )
        self.grid_height = grid_height
        self.grid_width = grid_width
        self.sigma_init = sigma_init
        self.sigma_final = sigma_final

        # Precompute grid positions: row k holds the (row, col) coordinate
        # of prototype k, in row-major order.
        rows, cols = jnp.meshgrid(
            jnp.arange(grid_height), jnp.arange(grid_width), indexing='ij'
        )
        self._grid_positions = jnp.stack(
            [rows.ravel(), cols.ravel()], axis=1
        ).astype(jnp.float32)

    # ------------------------------------------------------------------
    # Shared numerical helpers (used by both training paths and bmu_map)
    # ------------------------------------------------------------------

    def _pairwise_grid_dist_sq(self):
        """Return squared grid distances between all node pairs.

        Returns
        -------
        array of shape (n_protos, n_protos)
        """
        grid_pos = self._grid_positions
        return jnp.sum(
            (grid_pos[:, None, :] - grid_pos[None, :, :]) ** 2, axis=2
        )

    def _heskes_assign(self, distances, grid_dist_sq, sigma):
        """Compute the neighborhood matrix and the Heskes BMU per sample.

        Parameters
        ----------
        distances : array of shape (n_samples, n_protos)
            Squared distances from each sample to each prototype.
        grid_dist_sq : array of shape (n_protos, n_protos)
            Squared grid distances between nodes.
        sigma : float or scalar array
            Neighborhood radius.

        Returns
        -------
        bmu_indices : array of shape (n_samples,)
        h_matrix : array of shape (n_protos, n_protos)
        """
        # Neighborhood matrix: h(k, c) for all node pairs.
        h_matrix = jnp.exp(-grid_dist_sq / (2.0 * sigma ** 2))
        # Heskes BMU: c*(x) = argmin_c Σ_k h(k, c) * ||x - w_k||^2.
        # distances is (n_samples, k) and h_matrix is (k, c), so the
        # product holds the cost of every candidate c for every sample.
        weighted_cost = jnp.dot(distances, h_matrix)
        bmu_indices = jnp.argmin(weighted_cost, axis=1)
        return bmu_indices, h_matrix

    def _heskes_batch_update(self, X, prototypes, grid_dist_sq, sigma):
        """Run one Heskes batch update.

        Parameters
        ----------
        X : array of shape (n_samples, n_features)
        prototypes : array of shape (n_protos, n_features)
        grid_dist_sq : array of shape (n_protos, n_protos)
        sigma : float or scalar array
            Neighborhood radius for this iteration.

        Returns
        -------
        new_prototypes : array of shape (n_protos, n_features)
        energy : scalar array
            Heskes energy evaluated at the *input* prototypes, i.e.
            before the update is applied.
        """
        distances = squared_euclidean_distance_matrix(X, prototypes)
        bmu_indices, h_matrix = self._heskes_assign(
            distances, grid_dist_sq, sigma
        )

        # h(k, c*(x)) for each sample: select BMU columns of the (k, c)
        # matrix, then transpose to (n_samples, k).
        h_weights = h_matrix[:, bmu_indices].T

        # Batch update: w_k = Σ_x h(k, c*(x)) * x / Σ_x h(k, c*(x)).
        numerator = jnp.dot(h_weights.T, X)
        denominator = jnp.sum(h_weights, axis=0)[:, None]
        # The small constant guards against division by zero.
        new_prototypes = numerator / (denominator + 1e-10)

        # Heskes energy: E = Σ_x Σ_k h(k, c*(x)) * ||x - w_k||^2.
        energy = jnp.sum(h_weights * distances)
        return new_prototypes, energy

    # ------------------------------------------------------------------
    # lax.scan training path
    # ------------------------------------------------------------------

    @partial(jit, static_argnums=(0,))
    def _heskes_step(self, state, X, grid_dist_sq, sigma_init):
        """Single JIT-compiled Heskes SOM training step.

        Parameters
        ----------
        state : HeskesSOMState
        X : array of shape (n_samples, n_features)
        grid_dist_sq : array of shape (n_protos, n_protos)
        sigma_init : scalar array
            Initial neighborhood radius.

        Returns
        -------
        new_state : HeskesSOMState
        frozen_energy : scalar array
            Energy recorded for this step (held constant after
            convergence).
        """
        # Exponential annealing of sigma from sigma_init to sigma_final.
        t = state.iteration
        max_t = jnp.array(max(self.max_iter - 1, 1), dtype=jnp.float32)
        frac = t.astype(jnp.float32) / max_t
        sigma_t = sigma_init * (self.sigma_final / sigma_init) ** frac

        prototypes = state.prototypes
        new_prototypes, energy = self._heskes_batch_update(
            X, prototypes, grid_dist_sq, sigma_t
        )

        # Convergence is sticky; once reached, prototypes and the reported
        # energy are frozen because lax.scan cannot stop early.
        has_converged = state.converged | (
            jnp.abs(energy - state.prev_loss) < self.epsilon
        )
        frozen_prototypes = jnp.where(
            state.converged, prototypes, new_prototypes
        )
        frozen_energy = jnp.where(state.converged, state.loss, energy)

        new_state = HeskesSOMState(
            prototypes=frozen_prototypes,
            loss=frozen_energy,
            prev_loss=energy,
            converged=has_converged,
            iteration=t + 1,
        )
        return new_state, frozen_energy

    @partial(jit, static_argnums=(0,))
    def _fit_scan(self, X, prototypes, grid_dist_sq, sigma_init):
        """Scan-based training loop.

        Returns
        -------
        final_state : HeskesSOMState
        loss_history : array of shape (max_iter,)
        """
        initial_state = HeskesSOMState(
            prototypes=prototypes,
            loss=jnp.array(float('inf')),
            prev_loss=jnp.array(float('inf')),
            converged=jnp.array(False),
            iteration=jnp.array(0),
        )

        def scan_fn(state, _):
            return self._heskes_step(state, X, grid_dist_sq, sigma_init)

        final_state, loss_history = lax.scan(
            scan_fn, initial_state, None, length=self.max_iter
        )
        return final_state, loss_history

    # ------------------------------------------------------------------
    # Public training entry point
    # ------------------------------------------------------------------

    def fit(self, X):
        """Fit HeskesSOM.

        Prototypes are initialised from distinct randomly chosen samples
        of ``X``.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``X`` has fewer samples than the grid has nodes.
        """
        X = jnp.asarray(X, dtype=jnp.float32)
        n_samples = X.shape[0]

        # Sampling without replacement needs at least one sample per node;
        # fail early with an actionable message.
        if n_samples < self.n_prototypes:
            raise ValueError(
                f"HeskesSOM needs at least {self.n_prototypes} samples to "
                f"initialise a {self.grid_height}x{self.grid_width} grid, "
                f"got {n_samples}."
            )

        key = self.key
        indices = jax.random.choice(
            key, n_samples, (self.n_prototypes,), replace=False
        )
        prototypes = X[indices]

        # Truthiness check on purpose: None and 0 both select the default.
        sigma_init_val = (
            self.sigma_init if self.sigma_init
            else max(self.grid_height, self.grid_width) / 2.0
        )

        grid_dist_sq = self._pairwise_grid_dist_sq()

        # patience / restore_best need per-epoch Python logic, so they
        # force the Python loop even when use_scan is True.
        if self.use_scan and self.patience is None and not self.restore_best:
            return self._fit_with_scan(
                X, prototypes, grid_dist_sq, sigma_init_val
            )
        return self._fit_with_python_loop(
            X, prototypes, grid_dist_sq, sigma_init_val
        )

    def _fit_with_scan(self, X, prototypes, grid_dist_sq, sigma_init_val):
        """lax.scan training.

        Sets ``prototypes_``, ``n_iter_``, ``loss_`` and ``loss_history_``
        and returns ``self``.
        """
        sigma_init = jnp.array(sigma_init_val, dtype=jnp.float32)
        final_state, loss_history = self._fit_scan(
            X, prototypes, grid_dist_sq, sigma_init
        )

        # Recover the iteration at which convergence was first detected:
        # diff index i compares steps i and i + 1, and step i + 1 is the
        # (i + 2)-th iteration.
        converged_mask = jnp.abs(jnp.diff(loss_history)) < self.epsilon
        first_converged = jnp.argmax(converged_mask)
        has_any = jnp.any(converged_mask)
        n_iter = jnp.where(has_any, first_converged + 2, self.max_iter)

        self.prototypes_ = final_state.prototypes
        self.n_iter_ = int(n_iter)
        self.loss_ = float(final_state.loss)
        self.loss_history_ = loss_history
        return self

    def _fit_with_python_loop(self, X, prototypes, grid_dist_sq,
                              sigma_init_val):
        """Python for-loop training with true early stopping.

        Sets ``prototypes_``, ``n_iter_``, ``loss_`` and ``loss_history_``
        (plus ``best_loss_`` when ``restore_best`` restores a snapshot)
        and returns ``self``.
        """
        loss_history = []
        best_loss = None
        best_prototypes = None

        for t in range(self.max_iter):
            # Exponential annealing of sigma, same schedule as the scan path.
            frac = t / max(self.max_iter - 1, 1)
            sigma_t = sigma_init_val * (
                self.sigma_final / sigma_init_val
            ) ** frac

            prototypes, energy = self._heskes_batch_update(
                X, prototypes, grid_dist_sq, sigma_t
            )
            energy = float(energy)
            loss_history.append(energy)

            if self.restore_best and (
                best_loss is None or energy < best_loss
            ):
                best_loss = energy
                best_prototypes = prototypes

            if t > 0 and abs(
                loss_history[-1] - loss_history[-2]
            ) < self.epsilon:
                break

            if self.patience is not None and self._check_patience(
                loss_history, self.patience
            ):
                break

        if self.restore_best and best_prototypes is not None:
            prototypes = best_prototypes
            self.best_loss_ = best_loss

        self.prototypes_ = prototypes
        # One history entry is appended per executed iteration.
        self.n_iter_ = len(loss_history)
        self.loss_ = loss_history[-1]
        self.loss_history_ = jnp.array(loss_history)
        return self

    # ------------------------------------------------------------------
    # Inference and introspection
    # ------------------------------------------------------------------

    def bmu_map(self, X):
        """Return BMU grid coordinates using Heskes criterion.

        Parameters
        ----------
        X : array of shape (n, d)

        Returns
        -------
        coords : array of shape (n, 2) — (row, col) for each sample
        """
        self._check_fitted()
        X = jnp.asarray(X, dtype=jnp.float32)
        distances = squared_euclidean_distance_matrix(X, self.prototypes_)
        grid_dist_sq = self._pairwise_grid_dist_sq()

        # Use the final (small) sigma for inference: tight neighborhood.
        bmu_indices, _ = self._heskes_assign(
            distances, grid_dist_sq, self.sigma_final
        )
        return self._grid_positions[bmu_indices]

    def _get_hyperparams(self):
        """Return the base-class hyperparameters plus the grid settings.

        ``sigma_init`` is only included when it was set explicitly.
        """
        hp = super()._get_hyperparams()
        hp.update({
            'grid_height': self.grid_height,
            'grid_width': self.grid_width,
            'sigma_final': self.sigma_final,
        })
        if self.sigma_init is not None:
            hp['sigma_init'] = self.sigma_init
        return hp