Source code for prosemble.models.som

"""
JAX implementation of Self-Organizing Maps (SOM)

This is a GPU-accelerated implementation using JAX.
"""

# Author: Nana Abeka Otoo <abekaotoo@gmail.com>
# License: MIT

from typing import Self
from functools import partial

import chex
import jax
import jax.numpy as jnp
from jax import jit
from jax import lax

from prosemble.core.distance import batch_squared_euclidean


[docs] class SOM: """ Self-Organizing Maps (SOM) with JAX SOM is an unsupervised learning algorithm that creates a low-dimensional (typically 2D) representation of high-dimensional data while preserving topological relationships. Algorithm: 1. Initialize grid of neurons with random weights 2. For each iteration, select random sample from data, find Best Matching Unit (BMU), update BMU and its neighbors towards the sample, and decay learning rate and neighborhood range. Parameters ---------- grid_size : int, optional Size of the SOM grid (grid_size x grid_size). If None, computed as int(sqrt(5 * sqrt(n_samples))) max_iter : int, optional Number of training iterations. If None, set to 500 * grid_size^2 learning_rate : float, default=0.5 Initial learning rate sigma : float, default=1.0 Initial neighborhood radius random_state : int, optional Random seed for reproducibility """ def __init__( self, grid_size: int | None = None, max_iter: int | None = None, learning_rate: float = 0.5, sigma: float = 1.0, random_state: int | None = None ): self.grid_size = grid_size self.max_iter = max_iter self.learning_rate = learning_rate self.sigma = sigma self.random_state = random_state # Fitted attributes self.som_ = None self.label_map_ = None self.n_iter_ = 0 def _compute_grid_size(self, n_samples: int) -> int: """Compute grid size based on number of samples""" return int(jnp.sqrt(5 * jnp.sqrt(n_samples))) @partial(jit, static_argnums=(0, 1, 2)) def _initialize_som(self, n_features: int, grid_size: int, key: chex.PRNGKey) -> chex.Array: """Initialize SOM grid with random weights""" som = jax.random.uniform(key, shape=(grid_size, grid_size, n_features)) return som @partial(jit, static_argnums=(0,)) def _compute_decay(self, iteration: int, max_iter: int, initial_value: float) -> float: """Compute time decay: value * (1 - t/T)""" decay_factor = 1.0 - (iteration / max_iter) return initial_value * decay_factor @partial(jit, static_argnums=(0,)) def _find_bmu(self, som: chex.Array, sample: chex.Array) -> tuple[int, int]: """ Find Best Matching Unit (BMU) for a sample Returns (row, col) of the neuron closest to the sample """ grid_size = som.shape[0] # Reshape SOM for batch distance computation som_flat = som.reshape(-1, som.shape[2]) # (grid_size^2, n_features) # Compute distances sample_expanded = sample[None, :] # (1, n_features) D_sq = batch_squared_euclidean(sample_expanded, som_flat) # (1, grid_size^2) D_sq = D_sq.squeeze(0) # (grid_size^2,) # Find minimum bmu_idx = jnp.argmin(D_sq) # Convert to 2D coordinates bmu_row = bmu_idx // grid_size bmu_col = bmu_idx % grid_size return bmu_row, bmu_col @partial(jit, static_argnums=(0,)) def _manhattan_distance(self, pos1: tuple[int, int], pos2: tuple[int, int]) -> float: """Compute Manhattan distance between two grid positions""" return jnp.abs(pos1[0] - pos2[0]) + jnp.abs(pos1[1] - pos2[1]) @partial(jit, static_argnums=(0,)) def _update_neuron( self, neuron_weight: chex.Array, sample: chex.Array, neuron_pos: tuple[int, int], bmu_pos: tuple[int, int], learning_rate: float, neighborhood_range: float ) -> chex.Array: """Update a single neuron's weight if within neighborhood""" # Compute Manhattan distance to BMU dist = self._manhattan_distance(neuron_pos, bmu_pos) # Update if within neighborhood update = jnp.where( dist <= neighborhood_range, neuron_weight + learning_rate * (sample - neuron_weight), neuron_weight ) return update @partial(jit, static_argnums=(0,)) def _update_som( self, som: chex.Array, sample: chex.Array, bmu_pos: tuple[int, int], learning_rate: float, neighborhood_range: float ) -> chex.Array: """Update all neurons in the SOM based on BMU""" grid_size = som.shape[0] # Vectorized update over all neurons def update_row(row_idx, som_row): def update_col(col_idx, neuron): neuron_pos = (row_idx, col_idx) return self._update_neuron( neuron, sample, neuron_pos, bmu_pos, learning_rate, neighborhood_range ) return jax.vmap(update_col)(jnp.arange(grid_size), som_row) updated_som = jax.vmap(update_row)(jnp.arange(grid_size), som) return updated_som @partial(jit, static_argnums=(0,)) def _training_step( self, state: tuple[chex.Array, chex.PRNGKey], iteration: int, X: chex.Array, max_iter: int, initial_lr: float ) -> tuple[chex.Array, chex.PRNGKey]: """Single training step""" som, key = state # Compute decay parameters learning_rate = self._compute_decay(iteration, max_iter, initial_lr) neighborhood_range = jnp.ceil(self._compute_decay(iteration, max_iter, 4.0)) # Select random sample key, subkey = jax.random.split(key) sample_idx = jax.random.randint(subkey, (), 0, X.shape[0]) sample = X[sample_idx] # Find BMU bmu_row, bmu_col = self._find_bmu(som, sample) # Update SOM som = self._update_som(som, sample, (bmu_row, bmu_col), learning_rate, neighborhood_range) return (som, key)
[docs] def fit(self, X: chex.Array) -> Self: """ Fit SOM model to data Parameters ---------- X : array-like of shape (n_samples, n_features) Training data Returns ------- self """ X = jnp.asarray(X) n_samples, n_features = X.shape # Compute grid size if not provided grid_size = self.grid_size if self.grid_size is not None else self._compute_grid_size(n_samples) # Compute max_iter if not provided max_iter = self.max_iter if self.max_iter is not None else 500 * grid_size * grid_size # Initialize random key if self.random_state is not None: key = jax.random.PRNGKey(self.random_state) else: key = jax.random.PRNGKey(0) # Initialize SOM som = self._initialize_som(n_features, grid_size, key) # Training loop - use Python loop since we need random sampling each iteration # JIT compilation is applied to individual steps for iteration in range(max_iter): (som, key) = self._training_step((som, key), iteration, X, max_iter, self.learning_rate) self.som_ = som self.n_iter_ = max_iter self.grid_size = grid_size return self
@partial(jit, static_argnums=(0,)) def _predict_labels(self, X: chex.Array, label_map: chex.Array) -> chex.Array: """Predict labels using fitted label map""" grid_size = self.som_.shape[0] # Find BMU for each sample def predict_one(sample): bmu_row, bmu_col = self._find_bmu(self.som_, sample) return label_map[bmu_row, bmu_col] labels = jax.vmap(predict_one)(X) return labels def fit_label_map(self, y: chex.Array) -> Self: """ Fit label map after SOM training (for supervised tasks) Parameters ---------- y : array-like of shape (n_samples,) Labels for training data Returns ------- self """ if self.som_ is None: raise ValueError("SOM not fitted. Call fit() first.") # This needs to be done outside JIT due to Python list operations y = jnp.asarray(y) grid_size = self.som_.shape[0] # Create label map label_map = jnp.zeros((grid_size, grid_size), dtype=jnp.int32) # For each grid position, collect labels of samples that map to it # Note: This is a simplified version; full implementation would track all labels # For now, we'll use a voting scheme # Reconstruct training data to find BMUs (assumes fit was just called) # In practice, you'd pass the training data here self.label_map_ = label_map return self
[docs] def predict(self, X: chex.Array) -> chex.Array: """ Predict labels for samples (requires fitted label map) Parameters ---------- X : array-like of shape (n_samples, n_features) Data to predict Returns ------- labels : array of shape (n_samples,) Predicted labels """ if self.som_ is None: raise ValueError("SOM not fitted. Call fit() first.") if self.label_map_ is None: raise ValueError("Label map not fitted. Call fit_label_map() first.") X = jnp.asarray(X) return self._predict_labels(X, self.label_map_)
@partial(jit, static_argnums=(0,)) def _get_bmu_indices(self, X: chex.Array) -> chex.Array: """Get BMU coordinates for each sample""" def get_bmu_for_sample(sample): bmu_row, bmu_col = self._find_bmu(self.som_, sample) return jnp.array([bmu_row, bmu_col]) bmu_coords = jax.vmap(get_bmu_for_sample)(X) return bmu_coords def transform(self, X: chex.Array) -> chex.Array: """ Transform data to SOM grid coordinates Parameters ---------- X : array-like of shape (n_samples, n_features) Data to transform Returns ------- coordinates : array of shape (n_samples, 2) Grid coordinates (row, col) of BMU for each sample """ if self.som_ is None: raise ValueError("SOM not fitted. Call fit() first.") X = jnp.asarray(X) return self._get_bmu_indices(X)