"""
Neural Gas algorithm.
Topology-preserving unsupervised learning with rank-based
neighborhood adaptation and exponential decay.
References
----------
.. [1] Martinetz, T. M., Berkovich, S. G., & Schulten, K. J. (1993).
"Neural-gas" network for vector quantization and its application
to time-series prediction. IEEE Trans. Neural Networks.
"""
from typing import NamedTuple
from functools import partial
import jax
import jax.numpy as jnp
import numpy as np
from jax import jit, lax
from prosemble.models.prototype_base import UnsupervisedPrototypeModel
from prosemble.core.distance import squared_euclidean_distance_matrix
class NGState(NamedTuple):
    """Carry value threaded through the Neural Gas ``lax.scan`` loop."""

    # Current prototype positions.
    prototypes: jnp.ndarray
    # Loss reported for the latest step (held fixed once converged).
    loss: jnp.ndarray
    # Energy of the preceding step; the convergence test compares against it.
    prev_loss: jnp.ndarray
    # Boolean flag, latched once the energy change falls below the threshold.
    converged: jnp.ndarray
    # Zero-based step counter that drives the decay schedules.
    iteration: jnp.ndarray
class NeuralGas(UnsupervisedPrototypeModel):
    """Neural Gas.

    Updates all prototypes based on rank-distance:

    .. math::

        h(\\text{rank}, \\lambda) = \\exp(-\\text{rank} / \\lambda)

    .. math::

        w_k \\leftarrow w_k + \\varepsilon \\cdot h(\\text{rank}_k) \\cdot (x - w_k)

    Both :math:`\\varepsilon` and :math:`\\lambda` decay exponentially during
    training, from their ``*_init`` value at the first iteration to their
    ``*_final`` value at iteration ``max_iter - 1``.  Each iteration is a
    batch step: the weighted differences are averaged over all samples.

    Parameters
    ----------
    n_prototypes : int
        Number of prototypes/nodes.
    lr_init : float
        Initial learning rate of the decay schedule.
    lr_final : float
        Final learning rate of the decay schedule.
    lambda_init : float, optional
        Initial neighborhood range. Default: n_prototypes / 2.
    lambda_final : float
        Final neighborhood range.
    max_iter : int
        Maximum training iterations. Must be at least 1.
    lr : float
        Forwarded to the base class. The update rule in this class uses the
        ``lr_init`` -> ``lr_final`` schedule and does not read this value.
    epsilon : float
        Convergence threshold on the absolute change in energy between two
        consecutive iterations.
    random_seed : int
        Random seed.
    distance_fn : callable, optional
        Distance function. Forwarded to the base class; the training loops
        defined in this class always rank prototypes with
        ``squared_euclidean_distance_matrix``.
    callbacks : list, optional
        Callback objects. Forwarded to the base class; the training loops
        defined in this class do not reference them.
    use_scan : bool
        If True (default), use jax.lax.scan for training (faster, JIT-compiled,
        but runs all max_iter iterations even after convergence).
        If False, use a Python for-loop with true early stopping.
    patience : int, optional
        Epochs with no improvement before early stopping. Default: None.
        Forwarded to the base class; the training loops defined in this
        class do not reference it.
    restore_best : bool
        If True, restore parameters from the lowest-loss epoch. Default: False.
        Forwarded to the base class; the training loops defined in this
        class do not reference it.

    Attributes
    ----------
    prototypes_ : jnp.ndarray
        Prototype positions after training.
    n_iter_ : int
        Number of iterations performed up to and including the one at which
        convergence was detected (``max_iter`` if it never was).
    loss_ : float
        Energy recorded for the last effective iteration.
    loss_history_ : jnp.ndarray
        Energy recorded per iteration.
    """

    def __init__(self, n_prototypes, lr_init=0.5, lr_final=0.01,
                 lambda_init=None, lambda_final=0.01,
                 max_iter=100, lr=0.01, epsilon=1e-6, random_seed=42,
                 distance_fn=None, callbacks=None, use_scan=True,
                 patience=None, restore_best=False):
        super().__init__(
            n_prototypes=n_prototypes, max_iter=max_iter, lr=lr,
            epsilon=epsilon, random_seed=random_seed, distance_fn=distance_fn,
            callbacks=callbacks, use_scan=use_scan, patience=patience,
            restore_best=restore_best,
        )
        self.lr_init = lr_init
        self.lr_final = lr_final
        self.lambda_init = lambda_init  # None -> n_prototypes / 2, resolved in fit()
        self.lambda_final = lambda_final

    @staticmethod
    def _ng_batch_update(X, prototypes, lr_t, lam_t):
        """Apply one rank-weighted batch update.

        Shared by the scan path and the Python-loop path so both use exactly
        the same update rule.

        Returns
        -------
        tuple
            ``(new_prototypes, energy)`` where ``energy`` is the sum of
            ``h * distances`` evaluated at the prototypes *before* the update.
        """
        distances = squared_euclidean_distance_matrix(X, prototypes)
        # argsort of argsort turns each row of distances into ranks:
        # 0 for the closest prototype, 1 for the next, and so on.
        order = jnp.argsort(distances, axis=1)
        ranks = jnp.argsort(order, axis=1).astype(jnp.float32)
        h = jnp.exp(-ranks / lam_t)
        # Pairwise sample-minus-prototype differences, weighted by h and
        # averaged over the sample axis.
        diffs = X[:, None, :] - prototypes[None, :, :]
        update = lr_t * jnp.mean(h[:, :, None] * diffs, axis=0)
        energy = jnp.sum(h * distances)
        return prototypes + update, energy

    # NOTE(review): `self` is a static jit argument here and in `_fit_scan`, so
    # the values read from `self` are treated as compile-time constants. Per
    # the JAX FAQ on jitting methods, mutating hyperparameters on an instance
    # after its first fit may reuse the earlier trace - confirm whether refit
    # after changing hyperparameters needs to be supported.
    @partial(jit, static_argnums=(0,))
    def _ng_step(self, state, X, lambda_init):
        """Single JIT-compiled Neural Gas training step.

        Parameters
        ----------
        state : NGState
            Carry from the previous step.
        X : jnp.ndarray
            Training samples.
        lambda_init : jnp.ndarray
            Resolved initial neighborhood range.

        Returns
        -------
        tuple
            ``(new_state, loss)`` in the form expected by ``lax.scan``.
        """
        t = state.iteration
        # The schedule reaches its final values at t == max_iter - 1; the
        # max(..., 1) avoids a zero denominator when max_iter == 1.
        max_t = jnp.array(max(self.max_iter - 1, 1), dtype=jnp.float32)
        frac = t.astype(jnp.float32) / max_t
        lr_t = self.lr_init * (self.lr_final / self.lr_init) ** frac
        lam_t = lambda_init * (self.lambda_final / lambda_init) ** frac

        prototypes = state.prototypes
        new_prototypes, energy = self._ng_batch_update(X, prototypes, lr_t, lam_t)

        # Convergence: latch once the energy change drops below epsilon.
        has_converged = state.converged | (
            jnp.abs(energy - state.prev_loss) < self.epsilon
        )
        # Freezing keys off the flag from the *previous* step, so the step
        # that first detects convergence still applies its update.
        frozen_prototypes = jnp.where(state.converged, prototypes, new_prototypes)
        frozen_energy = jnp.where(state.converged, state.loss, energy)

        new_state = NGState(
            prototypes=frozen_prototypes,
            loss=frozen_energy,
            prev_loss=energy,
            converged=has_converged,
            iteration=t + 1,
        )
        return new_state, frozen_energy

    @partial(jit, static_argnums=(0,))
    def _fit_scan(self, X, prototypes, lambda_init):
        """Scan-based training loop.

        Runs ``_ng_step`` exactly ``max_iter`` times and returns the final
        ``NGState`` together with the per-iteration loss history.
        """
        initial_state = NGState(
            prototypes=prototypes,
            loss=jnp.array(float('inf')),
            prev_loss=jnp.array(float('inf')),
            converged=jnp.array(False),
            iteration=jnp.array(0),
        )

        def scan_fn(state, _):
            return self._ng_step(state, X, lambda_init)

        final_state, loss_history = lax.scan(
            scan_fn, initial_state, None, length=self.max_iter
        )
        return final_state, loss_history

    def fit(self, X):
        """Fit Neural Gas.

        Parameters
        ----------
        X : array-like
            Training samples; converted to a float32 ``jnp`` array.

        Returns
        -------
        NeuralGas
            ``self``, with ``prototypes_``, ``n_iter_``, ``loss_`` and
            ``loss_history_`` set.

        Raises
        ------
        ValueError
            If ``max_iter`` is smaller than 1.
        """
        if self.max_iter < 1:
            raise ValueError(
                f"max_iter must be at least 1, got {self.max_iter!r}"
            )

        X = jnp.asarray(X, dtype=jnp.float32)
        n_samples = X.shape[0]

        # Initialise prototypes on distinct, randomly chosen training samples.
        indices = jax.random.choice(
            self.key, n_samples, (self.n_prototypes,), replace=False
        )
        prototypes = X[indices]

        # None selects the documented default. A value of 0 is replaced by the
        # default as well, because the decay schedule divides by lambda_init.
        if self.lambda_init:
            lambda_init_val = self.lambda_init
        else:
            lambda_init_val = self.n_prototypes / 2.0

        if self.use_scan:
            return self._fit_with_scan(X, prototypes, lambda_init_val)
        return self._fit_with_python_loop(X, prototypes, lambda_init_val)

    def _fit_with_scan(self, X, prototypes, lambda_init_val):
        """lax.scan training: JIT-compiled, runs all max_iter iterations."""
        lambda_init = jnp.array(lambda_init_val, dtype=jnp.float32)
        final_state, loss_history = self._fit_scan(X, prototypes, lambda_init)

        # Recover the iteration count from the history: the first pair of
        # consecutive losses closer than epsilon marks the step at which the
        # scan detected convergence. With a single recorded loss there is no
        # pair to compare (and argmax of an empty array would raise).
        n_iter = self.max_iter
        if loss_history.shape[0] > 1:
            converged_mask = jnp.abs(jnp.diff(loss_history)) < self.epsilon
            if bool(jnp.any(converged_mask)):
                n_iter = int(jnp.argmax(converged_mask)) + 2

        self.prototypes_ = final_state.prototypes
        self.n_iter_ = int(n_iter)
        self.loss_ = float(final_state.loss)
        self.loss_history_ = loss_history
        return self

    def _fit_with_python_loop(self, X, prototypes, lambda_init_val):
        """Python for-loop training: true early stopping, no wasted compute."""
        max_t = max(self.max_iter - 1, 1)
        loss_history = []
        for t in range(self.max_iter):
            frac = t / max_t
            lr_t = self.lr_init * (self.lr_final / self.lr_init) ** frac
            lam_t = lambda_init_val * (self.lambda_final / lambda_init_val) ** frac

            prototypes, energy = self._ng_batch_update(X, prototypes, lr_t, lam_t)
            loss_history.append(float(energy))

            if t > 0 and abs(loss_history[-1] - loss_history[-2]) < self.epsilon:
                break

        self.prototypes_ = prototypes
        # One loss is appended per executed iteration, so the history length
        # is the iteration count (and is defined even if the loop never ran).
        self.n_iter_ = len(loss_history)
        self.loss_ = loss_history[-1]
        self.loss_history_ = jnp.array(loss_history)
        return self

    def _get_hyperparams(self):
        """Return the base-class hyperparameters extended with the NG schedule.

        ``lambda_init`` is included only when it was set explicitly.
        """
        hp = super()._get_hyperparams()
        hp.update({
            'lr_init': self.lr_init,
            'lr_final': self.lr_final,
            'lambda_final': self.lambda_final,
        })
        if self.lambda_init is not None:
            hp['lambda_init'] = self.lambda_init
        return hp