"""
JAX implementation of Self-Organizing Maps (SOM)
This is a GPU-accelerated implementation using JAX.
"""
# Author: Nana Abeka Otoo <abekaotoo@gmail.com>
# License: MIT
from typing import Self
from functools import partial
import chex
import jax
import jax.numpy as jnp
from jax import jit
from jax import lax
from prosemble.core.distance import batch_squared_euclidean
[docs]
class SOM:
"""
Self-Organizing Maps (SOM) with JAX
SOM is an unsupervised learning algorithm that creates a low-dimensional
(typically 2D) representation of high-dimensional data while preserving
topological relationships.
Algorithm:
1. Initialize grid of neurons with random weights
2. For each iteration, select random sample from data,
find Best Matching Unit (BMU), update BMU and its neighbors
towards the sample, and decay learning rate and neighborhood range.
Parameters
----------
grid_size : int, optional
Size of the SOM grid (grid_size x grid_size).
If None, computed as int(sqrt(5 * sqrt(n_samples)))
max_iter : int, optional
Number of training iterations.
If None, set to 500 * grid_size^2
learning_rate : float, default=0.5
Initial learning rate
sigma : float, default=1.0
Initial neighborhood radius
random_state : int, optional
Random seed for reproducibility
"""
def __init__(
self,
grid_size: int | None = None,
max_iter: int | None = None,
learning_rate: float = 0.5,
sigma: float = 1.0,
random_state: int | None = None
):
self.grid_size = grid_size
self.max_iter = max_iter
self.learning_rate = learning_rate
self.sigma = sigma
self.random_state = random_state
# Fitted attributes
self.som_ = None
self.label_map_ = None
self.n_iter_ = 0
def _compute_grid_size(self, n_samples: int) -> int:
"""Compute grid size based on number of samples"""
return int(jnp.sqrt(5 * jnp.sqrt(n_samples)))
@partial(jit, static_argnums=(0, 1, 2))
def _initialize_som(self, n_features: int, grid_size: int, key: chex.PRNGKey) -> chex.Array:
"""Initialize SOM grid with random weights"""
som = jax.random.uniform(key, shape=(grid_size, grid_size, n_features))
return som
@partial(jit, static_argnums=(0,))
def _compute_decay(self, iteration: int, max_iter: int, initial_value: float) -> float:
"""Compute time decay: value * (1 - t/T)"""
decay_factor = 1.0 - (iteration / max_iter)
return initial_value * decay_factor
@partial(jit, static_argnums=(0,))
def _find_bmu(self, som: chex.Array, sample: chex.Array) -> tuple[int, int]:
"""
Find Best Matching Unit (BMU) for a sample
Returns (row, col) of the neuron closest to the sample
"""
grid_size = som.shape[0]
# Reshape SOM for batch distance computation
som_flat = som.reshape(-1, som.shape[2]) # (grid_size^2, n_features)
# Compute distances
sample_expanded = sample[None, :] # (1, n_features)
D_sq = batch_squared_euclidean(sample_expanded, som_flat) # (1, grid_size^2)
D_sq = D_sq.squeeze(0) # (grid_size^2,)
# Find minimum
bmu_idx = jnp.argmin(D_sq)
# Convert to 2D coordinates
bmu_row = bmu_idx // grid_size
bmu_col = bmu_idx % grid_size
return bmu_row, bmu_col
@partial(jit, static_argnums=(0,))
def _manhattan_distance(self, pos1: tuple[int, int], pos2: tuple[int, int]) -> float:
"""Compute Manhattan distance between two grid positions"""
return jnp.abs(pos1[0] - pos2[0]) + jnp.abs(pos1[1] - pos2[1])
@partial(jit, static_argnums=(0,))
def _update_neuron(
self,
neuron_weight: chex.Array,
sample: chex.Array,
neuron_pos: tuple[int, int],
bmu_pos: tuple[int, int],
learning_rate: float,
neighborhood_range: float
) -> chex.Array:
"""Update a single neuron's weight if within neighborhood"""
# Compute Manhattan distance to BMU
dist = self._manhattan_distance(neuron_pos, bmu_pos)
# Update if within neighborhood
update = jnp.where(
dist <= neighborhood_range,
neuron_weight + learning_rate * (sample - neuron_weight),
neuron_weight
)
return update
@partial(jit, static_argnums=(0,))
def _update_som(
self,
som: chex.Array,
sample: chex.Array,
bmu_pos: tuple[int, int],
learning_rate: float,
neighborhood_range: float
) -> chex.Array:
"""Update all neurons in the SOM based on BMU"""
grid_size = som.shape[0]
# Vectorized update over all neurons
def update_row(row_idx, som_row):
def update_col(col_idx, neuron):
neuron_pos = (row_idx, col_idx)
return self._update_neuron(
neuron, sample, neuron_pos, bmu_pos,
learning_rate, neighborhood_range
)
return jax.vmap(update_col)(jnp.arange(grid_size), som_row)
updated_som = jax.vmap(update_row)(jnp.arange(grid_size), som)
return updated_som
@partial(jit, static_argnums=(0,))
def _training_step(
self,
state: tuple[chex.Array, chex.PRNGKey],
iteration: int,
X: chex.Array,
max_iter: int,
initial_lr: float
) -> tuple[chex.Array, chex.PRNGKey]:
"""Single training step"""
som, key = state
# Compute decay parameters
learning_rate = self._compute_decay(iteration, max_iter, initial_lr)
neighborhood_range = jnp.ceil(self._compute_decay(iteration, max_iter, 4.0))
# Select random sample
key, subkey = jax.random.split(key)
sample_idx = jax.random.randint(subkey, (), 0, X.shape[0])
sample = X[sample_idx]
# Find BMU
bmu_row, bmu_col = self._find_bmu(som, sample)
# Update SOM
som = self._update_som(som, sample, (bmu_row, bmu_col), learning_rate, neighborhood_range)
return (som, key)
[docs]
def fit(self, X: chex.Array) -> Self:
"""
Fit SOM model to data
Parameters
----------
X : array-like of shape (n_samples, n_features)
Training data
Returns
-------
self
"""
X = jnp.asarray(X)
n_samples, n_features = X.shape
# Compute grid size if not provided
grid_size = self.grid_size if self.grid_size is not None else self._compute_grid_size(n_samples)
# Compute max_iter if not provided
max_iter = self.max_iter if self.max_iter is not None else 500 * grid_size * grid_size
# Initialize random key
if self.random_state is not None:
key = jax.random.PRNGKey(self.random_state)
else:
key = jax.random.PRNGKey(0)
# Initialize SOM
som = self._initialize_som(n_features, grid_size, key)
# Training loop - use Python loop since we need random sampling each iteration
# JIT compilation is applied to individual steps
for iteration in range(max_iter):
(som, key) = self._training_step((som, key), iteration, X, max_iter, self.learning_rate)
self.som_ = som
self.n_iter_ = max_iter
self.grid_size = grid_size
return self
@partial(jit, static_argnums=(0,))
def _predict_labels(self, X: chex.Array, label_map: chex.Array) -> chex.Array:
"""Predict labels using fitted label map"""
grid_size = self.som_.shape[0]
# Find BMU for each sample
def predict_one(sample):
bmu_row, bmu_col = self._find_bmu(self.som_, sample)
return label_map[bmu_row, bmu_col]
labels = jax.vmap(predict_one)(X)
return labels
def fit_label_map(self, y: chex.Array) -> Self:
"""
Fit label map after SOM training (for supervised tasks)
Parameters
----------
y : array-like of shape (n_samples,)
Labels for training data
Returns
-------
self
"""
if self.som_ is None:
raise ValueError("SOM not fitted. Call fit() first.")
# This needs to be done outside JIT due to Python list operations
y = jnp.asarray(y)
grid_size = self.som_.shape[0]
# Create label map
label_map = jnp.zeros((grid_size, grid_size), dtype=jnp.int32)
# For each grid position, collect labels of samples that map to it
# Note: This is a simplified version; full implementation would track all labels
# For now, we'll use a voting scheme
# Reconstruct training data to find BMUs (assumes fit was just called)
# In practice, you'd pass the training data here
self.label_map_ = label_map
return self
[docs]
def predict(self, X: chex.Array) -> chex.Array:
"""
Predict labels for samples (requires fitted label map)
Parameters
----------
X : array-like of shape (n_samples, n_features)
Data to predict
Returns
-------
labels : array of shape (n_samples,)
Predicted labels
"""
if self.som_ is None:
raise ValueError("SOM not fitted. Call fit() first.")
if self.label_map_ is None:
raise ValueError("Label map not fitted. Call fit_label_map() first.")
X = jnp.asarray(X)
return self._predict_labels(X, self.label_map_)
@partial(jit, static_argnums=(0,))
def _get_bmu_indices(self, X: chex.Array) -> chex.Array:
"""Get BMU coordinates for each sample"""
def get_bmu_for_sample(sample):
bmu_row, bmu_col = self._find_bmu(self.som_, sample)
return jnp.array([bmu_row, bmu_col])
bmu_coords = jax.vmap(get_bmu_for_sample)(X)
return bmu_coords
def transform(self, X: chex.Array) -> chex.Array:
"""
Transform data to SOM grid coordinates
Parameters
----------
X : array-like of shape (n_samples, n_features)
Data to transform
Returns
-------
coordinates : array of shape (n_samples, 2)
Grid coordinates (row, col) of BMU for each sample
"""
if self.som_ is None:
raise ValueError("SOM not fitted. Call fit() first.")
X = jnp.asarray(X)
return self._get_bmu_indices(X)