"""
Growing Neural Gas (GNG).
An incremental self-organizing network that grows by inserting nodes
near the units with the largest accumulated error, and prunes edges
that exceed a maximum age.

References
----------
.. [1] Fritzke, B. (1995). A Growing Neural Gas Network Learns
   Topologies. Advances in Neural Information Processing Systems 7
   (NIPS), pp. 625-632.
"""
import jax
import jax.numpy as jnp
import numpy as np
from prosemble.models.prototype_base import UnsupervisedPrototypeModel
from prosemble.core.distance import squared_euclidean_distance_matrix
class GrowingNeuralGas(UnsupervisedPrototypeModel):
    """Growing Neural Gas.

    Starts with 2 nodes and grows by inserting nodes near the
    highest-error units. Connections between nodes have ages;
    edges whose age exceeds ``max_age`` are removed.

    Node adaptation, edge aging and node insertion follow Fritzke (1995).
    Unlike the original algorithm, nodes that end up without any edge
    are NOT removed; they remain in ``prototypes_`` with no entries in
    ``edges_``.

    Parameters
    ----------
    max_nodes : int
        Maximum number of nodes (must be at least 2).
    lr_winner : float
        Learning rate for the winning node.
    lr_neighbor : float
        Learning rate for the topological neighbors of the winner.
    max_age : int
        Maximum edge age; an edge whose age exceeds this value is removed.
    insert_interval : int
        Insert a new node every this many adaptation steps (samples).
        Must be at least 1.
    error_decay : float
        Factor by which all accumulated node errors are multiplied after
        every adaptation step.
    n_prototypes : int
        Forwarded to the base class. :meth:`fit` always starts from
        2 nodes regardless of this value.
    max_iter : int
        Maximum number of passes (epochs) over the data (at least 1).
    lr : float
        Forwarded to the base class; :meth:`fit` uses ``lr_winner`` and
        ``lr_neighbor`` instead.
    epsilon : float
        Convergence threshold on the change of the mean per-sample
        squared quantization error between consecutive epochs.
    random_seed : int
        Random seed.
    distance_fn : callable, optional
        Forwarded to the base class; :meth:`fit` always uses squared
        Euclidean distance.
    callbacks : list, optional
        Forwarded to the base class; not invoked by :meth:`fit`.
    use_scan : bool
        Forwarded to the base class; :meth:`fit` always runs a Python
        loop with early stopping on ``epsilon``.
    patience : int, optional
        Forwarded to the base class; not used by :meth:`fit`.
    restore_best : bool
        Forwarded to the base class; not used by :meth:`fit`.

    Attributes
    ----------
    prototypes_ : jax.Array, shape (n_active_, n_features)
        Positions of the active nodes.
    edges_ : jax.Array, shape (n_active_, n_active_)
        Edge ages between active nodes; -1 means "no edge".
    n_active_ : int
        Number of nodes after training.
    n_iter_ : int
        Number of epochs actually run.
    loss_ : float
        Mean per-sample squared quantization error of the last epoch.
    loss_history_ : jax.Array
        Per-epoch values of ``loss_``.
    """

    def __init__(self, max_nodes=100, lr_winner=0.1, lr_neighbor=0.01,
                 max_age=50, insert_interval=100, error_decay=0.995,
                 n_prototypes=2, max_iter=100, lr=0.01, epsilon=1e-6,
                 random_seed=42, distance_fn=None, callbacks=None,
                 use_scan=True, patience=None, restore_best=False):
        super().__init__(
            n_prototypes=n_prototypes, max_iter=max_iter, lr=lr,
            epsilon=epsilon, random_seed=random_seed, distance_fn=distance_fn,
            callbacks=callbacks, use_scan=use_scan, patience=patience,
            restore_best=restore_best,
        )
        self.max_nodes = max_nodes
        self.lr_winner = lr_winner
        self.lr_neighbor = lr_neighbor
        self.max_age = max_age
        self.insert_interval = insert_interval
        self.error_decay = error_decay
        # Topology (set by fit)
        self.edges_ = None
        self.n_active_ = None

    def fit(self, X):
        """Fit Growing Neural Gas to ``X``.

        Each epoch visits every sample once in a random order. For each
        sample:

        1. Find the two nearest nodes ``s1`` and ``s2``.
        2. ``s1`` accumulates the squared distance as error.
        3. Edges of ``s1`` are aged.
        4. ``s1`` and its neighbors move toward the sample.
        5. The ``s1``-``s2`` edge is created or reset to age 0.
        6. Edges older than ``max_age`` are removed.

        Every ``insert_interval`` steps a node is inserted, and after
        every step all node errors are decayed.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Training data with at least 2 samples.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``X`` is not 2-D or has fewer than 2 samples, or if
            ``max_nodes < 2``, ``insert_interval < 1`` or ``max_iter < 1``.
        """
        X = jnp.asarray(X, dtype=jnp.float32)
        if X.ndim != 2:
            raise ValueError(f"X must be 2-D, got shape {X.shape}.")
        n_samples, n_features = X.shape
        if n_samples < 2:
            raise ValueError("GrowingNeuralGas needs at least 2 samples.")
        if self.max_nodes < 2:
            raise ValueError("max_nodes must be at least 2.")
        if self.insert_interval < 1:
            raise ValueError("insert_interval must be at least 1.")
        if self.max_iter < 1:
            raise ValueError("max_iter must be at least 1.")

        # Train on a host NumPy copy: indexing a JAX array once per sample
        # dispatches a device op every time and dominates the runtime.
        X_np = np.asarray(X)

        key1, key2 = jax.random.split(self.key)
        idx = np.asarray(jax.random.choice(key1, n_samples, (2,), replace=False))

        # Pre-allocate for max_nodes; only the first n_active rows are live.
        prototypes = np.zeros((self.max_nodes, n_features), dtype=np.float32)
        prototypes[:2] = X_np[idx]
        # edges[i, j] is the age of edge i-j, or -1 when there is no edge.
        edges = np.full((self.max_nodes, self.max_nodes), -1, dtype=np.int32)
        errors = np.zeros(self.max_nodes, dtype=np.float32)
        n_active = 2
        step = 0
        loss_history = []

        for epoch in range(self.max_iter):
            # Accumulate in Python float (float64) to avoid float32 drift.
            epoch_error = 0.0
            perm = np.asarray(
                jax.random.permutation(jax.random.fold_in(key2, epoch), n_samples)
            )
            for sample_idx in perm:
                x = X_np[sample_idx]
                step += 1
                s1, s2, dist_s1 = self._nearest_two(prototypes[:n_active], x)
                errors[s1] += dist_s1
                epoch_error += float(dist_s1)
                self._adapt(prototypes, edges, n_active, s1, s2, x)
                if step % self.insert_interval == 0 and n_active < self.max_nodes:
                    n_active = self._insert_node(prototypes, edges, errors, n_active)
                errors[:n_active] *= self.error_decay
            loss_history.append(epoch_error / n_samples)
            if epoch > 0 and abs(loss_history[-1] - loss_history[-2]) < self.epsilon:
                break

        # Store only active nodes
        self.prototypes_ = jnp.array(prototypes[:n_active])
        self.edges_ = jnp.array(edges[:n_active, :n_active])
        self.n_active_ = n_active
        self.n_iter_ = len(loss_history)
        self.loss_ = loss_history[-1]
        self.loss_history_ = jnp.array(loss_history)
        return self

    @staticmethod
    def _nearest_two(active_protos, x):
        """Return (s1, s2, d1): the two nearest node indices and s1's squared distance."""
        dists = np.sum((active_protos - x) ** 2, axis=1)
        s1, s2 = np.argsort(dists)[:2]
        return s1, s2, dists[s1]

    def _adapt(self, prototypes, edges, n_active, s1, s2, x):
        """Apply one in-place GNG adaptation step for winner s1 and runner-up s2."""
        # Age every edge emanating from the winner. The diagonal is never
        # an edge, so s1 itself is not in nbrs.
        nbrs = np.flatnonzero(edges[s1, :n_active] >= 0)
        edges[s1, nbrs] += 1
        edges[nbrs, s1] += 1
        # Move the winner and its current topological neighbors toward x.
        prototypes[s1] += self.lr_winner * (x - prototypes[s1])
        prototypes[nbrs] += self.lr_neighbor * (x - prototypes[nbrs])
        # Create the s1-s2 edge, or reset its age. This happens AFTER aging
        # so a refreshed edge really has age 0.
        edges[s1, s2] = 0
        edges[s2, s1] = 0
        # Remove edges older than max_age. Only the live sub-matrix can hold
        # edges; `live` is a view, so the masked assignment writes through.
        live = edges[:n_active, :n_active]
        live[live > self.max_age] = -1

    @staticmethod
    def _insert_node(prototypes, edges, errors, n_active):
        """Insert a node between the max-error node and its max-error neighbor; return the new count."""
        q = int(np.argmax(errors[:n_active]))
        neighbors_q = np.flatnonzero(edges[q, :n_active] >= 0)
        if neighbors_q.size == 0:
            # q has no neighbors, so there is no edge to split this time.
            return n_active
        f = int(neighbors_q[np.argmax(errors[neighbors_q])])
        r = n_active
        prototypes[r] = 0.5 * (prototypes[q] + prototypes[f])
        # Replace edge q-f by the edges q-r and f-r, both with age 0.
        edges[q, f] = edges[f, q] = -1
        edges[q, r] = edges[r, q] = 0
        edges[f, r] = edges[r, f] = 0
        # r takes the mean error of q and f (computed before halving them).
        errors[r] = 0.5 * (errors[q] + errors[f])
        errors[q] *= 0.5
        errors[f] *= 0.5
        return n_active + 1

    def _get_hyperparams(self):
        hp = super()._get_hyperparams()
        hp.update({
            'max_nodes': self.max_nodes,
            'lr_winner': self.lr_winner,
            'lr_neighbor': self.lr_neighbor,
            'max_age': self.max_age,
            'insert_interval': self.insert_interval,
            'error_decay': self.error_decay,
        })
        return hp