"""
Kohonen Self-Organizing Map (standard textbook algorithm).
This implements the standard Kohonen SOM with Gaussian neighborhood
and exponential decay, distinct from prosemble's existing SOM.
References
----------
.. [1] Kohonen, T. (1990). The Self-Organizing Map. Proc. IEEE.
"""
from typing import NamedTuple
from functools import partial
import jax
import jax.numpy as jnp
import numpy as np
from jax import jit, lax
from prosemble.models.prototype_base import UnsupervisedPrototypeModel
from prosemble.core.distance import squared_euclidean_distance_matrix
class SOMState(NamedTuple):
"""State for Kohonen SOM lax.scan loop."""
prototypes: jnp.ndarray
loss: jnp.ndarray
prev_loss: jnp.ndarray
converged: jnp.ndarray
iteration: jnp.ndarray
[docs]
class KohonenSOM(UnsupervisedPrototypeModel):
"""Standard Kohonen Self-Organizing Map.
Uses squared Euclidean distance for BMU selection, Gaussian
neighborhood function, exponential decay for sigma and learning
rate, and batch updates.
Parameters
----------
grid_height : int
Height of the 2D grid.
grid_width : int
Width of the 2D grid.
sigma_init : float, optional
Initial neighborhood radius. Default: max(grid_height, grid_width) / 2.
sigma_final : float
Final neighborhood radius.
lr_init : float
Initial learning rate.
lr_final : float
Final learning rate.
max_iter : int
Maximum training iterations.
lr : float
Initial learning rate.
epsilon : float
Convergence threshold.
random_seed : int
Random seed.
distance_fn : callable, optional
Distance function.
callbacks : list, optional
Callback objects.
use_scan : bool
If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping.
patience : int, optional
Epochs with no improvement before early stopping. Default: None.
restore_best : bool
If True, restore parameters from the lowest-loss epoch. Default: False.
"""
def __init__(self, grid_height=10, grid_width=10,
sigma_init=None, sigma_final=0.5,
lr_init=0.5, lr_final=0.01,
max_iter=100, lr=0.01, epsilon=1e-6, random_seed=42,
distance_fn=None, callbacks=None, use_scan=True,
patience=None, restore_best=False):
n_prototypes = grid_height * grid_width
super().__init__(
n_prototypes=n_prototypes, max_iter=max_iter, lr=lr,
epsilon=epsilon, random_seed=random_seed, distance_fn=distance_fn,
callbacks=callbacks, use_scan=use_scan, patience=patience,
restore_best=restore_best,
)
self.grid_height = grid_height
self.grid_width = grid_width
self.sigma_init = sigma_init # default: max(h, w) / 2
self.sigma_final = sigma_final
self.lr_init = lr_init
self.lr_final = lr_final
# Precompute grid positions
rows, cols = jnp.meshgrid(
jnp.arange(grid_height), jnp.arange(grid_width), indexing='ij'
)
self._grid_positions = jnp.stack([rows.ravel(), cols.ravel()], axis=1).astype(jnp.float32)
@partial(jit, static_argnums=(0,))
def _som_step(self, state, X, grid_dist_sq, sigma_init):
"""Single JIT-compiled Kohonen SOM training step."""
t = state.iteration
max_t = jnp.array(max(self.max_iter - 1, 1), dtype=jnp.float32)
frac = t.astype(jnp.float32) / max_t
sigma_t = sigma_init * (self.sigma_final / sigma_init) ** frac
lr_t = self.lr_init * (self.lr_final / self.lr_init) ** frac
prototypes = state.prototypes
n_samples = X.shape[0]
# Find BMU
distances = squared_euclidean_distance_matrix(X, prototypes)
bmu_indices = jnp.argmin(distances, axis=1)
# Gaussian neighborhood
bmu_grid_dist_sq = grid_dist_sq[bmu_indices]
h = jnp.exp(-bmu_grid_dist_sq / (2.0 * sigma_t ** 2))
# Batch update
diffs = X[:, None, :] - prototypes[None, :, :]
weighted_diffs = h[:, :, None] * diffs
numerator = jnp.sum(weighted_diffs, axis=0)
denominator = jnp.sum(h, axis=0)[:, None]
update = lr_t * numerator / (denominator + 1e-10)
new_prototypes = prototypes + update
# Quantization error
bmu_dists = distances[jnp.arange(n_samples), bmu_indices]
qe = jnp.mean(bmu_dists)
# Convergence
has_converged = state.converged | (
jnp.abs(qe - state.prev_loss) < self.epsilon
)
frozen_prototypes = jnp.where(state.converged, prototypes, new_prototypes)
frozen_qe = jnp.where(state.converged, state.loss, qe)
new_state = SOMState(
prototypes=frozen_prototypes,
loss=frozen_qe,
prev_loss=qe,
converged=has_converged,
iteration=t + 1,
)
return new_state, frozen_qe
@partial(jit, static_argnums=(0,))
def _fit_scan(self, X, prototypes, grid_dist_sq, sigma_init):
"""Scan-based training loop."""
initial_state = SOMState(
prototypes=prototypes,
loss=jnp.array(float('inf')),
prev_loss=jnp.array(float('inf')),
converged=jnp.array(False),
iteration=jnp.array(0),
)
def scan_fn(state, _):
return self._som_step(state, X, grid_dist_sq, sigma_init)
final_state, loss_history = lax.scan(
scan_fn, initial_state, None, length=self.max_iter
)
return final_state, loss_history
[docs]
def fit(self, X):
"""Fit KohonenSOM."""
X = jnp.asarray(X, dtype=jnp.float32)
n_samples = X.shape[0]
key = self.key
indices = jax.random.choice(key, n_samples, (self.n_prototypes,), replace=False)
prototypes = X[indices]
sigma_init_val = self.sigma_init if self.sigma_init else max(self.grid_height, self.grid_width) / 2.0
# Precompute grid distances
grid_pos = self._grid_positions
grid_dist_sq = jnp.sum(
(grid_pos[:, None, :] - grid_pos[None, :, :]) ** 2, axis=2
)
if self.use_scan and self.patience is None and not self.restore_best:
return self._fit_with_scan(X, prototypes, grid_dist_sq, sigma_init_val)
else:
return self._fit_with_python_loop(X, prototypes, grid_dist_sq, sigma_init_val)
def _fit_with_scan(self, X, prototypes, grid_dist_sq, sigma_init_val):
"""lax.scan training: JIT-compiled, runs all max_iter iterations."""
sigma_init = jnp.array(sigma_init_val, dtype=jnp.float32)
final_state, loss_history = self._fit_scan(X, prototypes, grid_dist_sq, sigma_init)
converged_mask = jnp.abs(jnp.diff(loss_history)) < self.epsilon
first_converged = jnp.argmax(converged_mask)
has_any = jnp.any(converged_mask)
n_iter = jnp.where(has_any, first_converged + 2, self.max_iter)
self.prototypes_ = final_state.prototypes
self.n_iter_ = int(n_iter)
self.loss_ = float(final_state.loss)
self.loss_history_ = loss_history
return self
def _fit_with_python_loop(self, X, prototypes, grid_dist_sq, sigma_init_val):
"""Python for-loop training: true early stopping, no wasted compute."""
n_samples = X.shape[0]
loss_history = []
best_loss = None
best_prototypes = None
for t in range(self.max_iter):
frac = t / max(self.max_iter - 1, 1)
sigma_t = sigma_init_val * (self.sigma_final / sigma_init_val) ** frac
lr_t = self.lr_init * (self.lr_final / self.lr_init) ** frac
distances = squared_euclidean_distance_matrix(X, prototypes)
bmu_indices = jnp.argmin(distances, axis=1)
bmu_grid_dist_sq = grid_dist_sq[bmu_indices]
h = jnp.exp(-bmu_grid_dist_sq / (2.0 * sigma_t ** 2))
diffs = X[:, None, :] - prototypes[None, :, :]
weighted_diffs = h[:, :, None] * diffs
numerator = jnp.sum(weighted_diffs, axis=0)
denominator = jnp.sum(h, axis=0)[:, None]
update = lr_t * numerator / (denominator + 1e-10)
prototypes = prototypes + update
bmu_dists = distances[jnp.arange(n_samples), bmu_indices]
qe = float(jnp.mean(bmu_dists))
loss_history.append(qe)
if self.restore_best and (best_loss is None or qe < best_loss):
best_loss = qe
best_prototypes = prototypes
if t > 0 and abs(loss_history[-1] - loss_history[-2]) < self.epsilon:
break
if self.patience is not None and self._check_patience(loss_history, self.patience):
break
if self.restore_best and best_prototypes is not None:
prototypes = best_prototypes
self.best_loss_ = best_loss
self.prototypes_ = prototypes
self.n_iter_ = t + 1
self.loss_ = loss_history[-1]
self.loss_history_ = jnp.array(loss_history)
return self
def bmu_map(self, X):
"""Return BMU grid coordinates for each sample.
Parameters
----------
X : array of shape (n, d)
Returns
-------
coords : array of shape (n, 2) — (row, col) for each sample
"""
self._check_fitted()
X = jnp.asarray(X, dtype=jnp.float32)
distances = squared_euclidean_distance_matrix(X, self.prototypes_)
bmu_indices = jnp.argmin(distances, axis=1)
return self._grid_positions[bmu_indices]
def _get_hyperparams(self):
hp = super()._get_hyperparams()
hp.update({
'grid_height': self.grid_height,
'grid_width': self.grid_width,
'sigma_final': self.sigma_final,
'lr_init': self.lr_init,
'lr_final': self.lr_final,
})
if self.sigma_init is not None:
hp['sigma_init'] = self.sigma_init
return hp