Source code for prosemble.models.hcm

"""
JAX-based Hard C-Means (HCM) / K-Means clustering implementation.

This module provides a GPU-accelerated implementation of Hard C-Means (K-Means)
using JAX with JIT compilation for high performance.
"""

from typing import NamedTuple, Self
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import chex
from jax import jit, lax

from prosemble.models.base import FuzzyClusteringBase, ScanFitMixin


class HCMState(NamedTuple):
    """Immutable state for HCM iteration.

    Attributes:
        centroids: Cluster centroids, shape (n_clusters, n_features)
        labels: Hard cluster assignments, shape (n_samples,)
        objective: Sum of squared distances to assigned centroids
        iteration: Current iteration number
        converged: Whether algorithm has converged
    """
    centroids: chex.Array
    labels: chex.Array
    objective: chex.Array
    iteration: int
    converged: bool


[docs] class HCM(ScanFitMixin, FuzzyClusteringBase): """ Hard C-Means (K-Means) clustering with JAX. HCM assigns each data point to exactly one cluster (hard assignment) based on the nearest centroid. This is the classic K-Means algorithm. Algorithm: 1. Initialize centroids randomly or from data 2. Assign each point to nearest centroid 3. Update centroids as mean of assigned points 4. Repeat until convergence Objective function: .. math:: J = \\sum_i \\|x_i - v_{l_i}\\|^2 Parameters ---------- init_method : {'random', 'kmeans++'}, default='random' Method for initializing centroids. n_clusters : int Number of clusters (must be >= 2). max_iter : int Maximum number of iterations. epsilon : float Convergence threshold. random_seed : int Random seed for reproducibility. distance_fn : callable, optional Pairwise distance function. Default: squared Euclidean. patience : int, optional Epochs with no improvement before early stopping. Default: None. restore_best : bool If True, restore centroids from the lowest-objective epoch. Default: False. plot_steps : bool Whether to visualize clustering progress. Default: False. show_confidence : bool Whether to show confidence in visualization. Default: True. show_pca_variance : bool Whether to show PCA variance in visualization. Default: True. save_plot_path : str, optional Path to save final plot. callbacks : list, optional List of Callback objects for monitoring/visualization. Attributes ---------- centroids_ : array, shape (n_clusters, n_features) Final cluster centroids labels_ : array, shape (n_samples,) Hard cluster assignments for training data n_iter_ : int Number of iterations until convergence objective_ : float Final objective function value objective_history_ : array Objective value at each iteration Examples -------- >>> import jax.numpy as jnp >>> from prosemble.models import HCM >>> X = jnp.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8]]) >>> model = HCM(n_clusters=2, random_seed=42) >>> model.fit(X) >>> labels = model.predict(X) """ _hyperparams = ('init_method',) _fitted_array_names = ('labels_',) def __init__( self, n_clusters: int, max_iter: int = 100, epsilon: float = 1e-5, init_method: str = 'random', random_seed: int = 42, distance_fn=None, patience: int | None = None, restore_best: bool = False, plot_steps: bool = False, show_confidence: bool = True, show_pca_variance: bool = True, save_plot_path: str | None = None, callbacks=None, ): # Validate model-specific parameters if init_method not in ['random', 'kmeans++']: raise ValueError("init_method must be 'random' or 'kmeans++'") super().__init__( n_clusters=n_clusters, max_iter=max_iter, epsilon=epsilon, random_seed=random_seed, distance_fn=distance_fn, patience=patience, restore_best=restore_best, plot_steps=plot_steps, show_confidence=show_confidence, show_pca_variance=show_pca_variance, save_plot_path=save_plot_path, callbacks=callbacks, ) self.init_method = init_method # Model-specific fitted attributes self.labels_ = None def _initialize_centroids(self, X: chex.Array) -> chex.Array: """Initialize cluster centroids. Args: X: Input data, shape (n_samples, n_features) Returns: Initial centroids, shape (n_clusters, n_features) """ n_samples = X.shape[0] if self.init_method == 'random': # Randomly select data points as initial centroids indices = jax.random.choice( self.key, n_samples, shape=(self.n_clusters,), replace=False ) centroids = X[indices] elif self.init_method == 'kmeans++': # K-means++ initialization centroids = self._kmeans_plusplus_init(X) else: raise ValueError(f"Unknown init_method: {self.init_method}") return centroids def _kmeans_plusplus_init(self, X: chex.Array) -> chex.Array: """Initialize centroids using K-means++ algorithm. Args: X: Input data, shape (n_samples, n_features) Returns: Initial centroids, shape (n_clusters, n_features) """ n_samples, n_features = X.shape centroids = jnp.zeros((self.n_clusters, n_features)) # Choose first centroid uniformly at random key1, key2 = jax.random.split(self.key) first_idx = jax.random.choice(key1, n_samples) centroids = centroids.at[0].set(X[first_idx]) # Choose remaining centroids with probability proportional to distance squared for i in range(1, self.n_clusters): # Compute distances to nearest existing centroid D_sq = self.distance_fn(X, centroids[:i]) # (n_samples, i) min_distances = jnp.min(D_sq, axis=1) # (n_samples,) # Probability proportional to squared distance probs = min_distances / jnp.sum(min_distances) # Choose next centroid key2, subkey = jax.random.split(key2) next_idx = jax.random.choice(subkey, n_samples, p=probs) centroids = centroids.at[i].set(X[next_idx]) return centroids @partial(jit, static_argnums=(0,)) def _assign_labels(self, X: chex.Array, centroids: chex.Array) -> chex.Array: """Assign each data point to nearest centroid (hard assignment). Args: X: Input data, shape (n_samples, n_features) centroids: Current centroids, shape (n_clusters, n_features) Returns: labels: Hard cluster assignments, shape (n_samples,) """ # Compute squared distances to all centroids D_sq = self.distance_fn(X, centroids) # (n_samples, n_clusters) # Assign to nearest centroid labels = jnp.argmin(D_sq, axis=1) # (n_samples,) return labels @partial(jit, static_argnums=(0,)) def _update_centroids(self, X: chex.Array, labels: chex.Array) -> chex.Array: """Update centroids as mean of assigned points. Args: X: Input data, shape (n_samples, n_features) labels: Hard cluster assignments, shape (n_samples,) Returns: centroids: Updated centroids, shape (n_clusters, n_features) """ n_features = X.shape[1] def compute_centroid(cluster_idx): # Get mask for points assigned to this cluster mask = labels == cluster_idx # (n_samples,) # Count points in cluster count = jnp.sum(mask) # Compute mean of assigned points # If no points assigned, keep old centroid (handled by where) sum_points = jnp.sum(jnp.where(mask[:, None], X, 0.0), axis=0) centroid = jnp.where(count > 0, sum_points / count, 0.0) return centroid # Vectorize over clusters centroids = jax.vmap(compute_centroid)(jnp.arange(self.n_clusters)) return centroids @partial(jit, static_argnums=(0,)) def _compute_objective( self, X: chex.Array, labels: chex.Array, centroids: chex.Array ) -> chex.Array: """Compute HCM objective function. .. math:: J = \\sum_i \\|x_i - v_{l_i}\\|^2 Args: X: Input data, shape (n_samples, n_features) labels: Hard cluster assignments, shape (n_samples,) centroids: Current centroids, shape (n_clusters, n_features) Returns: objective: Scalar objective value """ # Get assigned centroids for each point assigned_centroids = centroids[labels] # (n_samples, n_features) # Compute squared distances diff = X - assigned_centroids sq_distances = jnp.sum(diff * diff, axis=1) # (n_samples,) # Sum over all points objective = jnp.sum(sq_distances) return objective @partial(jit, static_argnums=(0,)) def _iteration_step( self, state: HCMState, X: chex.Array ) -> tuple[HCMState, dict]: """Single HCM iteration step. Args: state: Current HCM state X: Input data, shape (n_samples, n_features) Returns: new_state: Updated HCM state metrics: Dictionary of metrics for this iteration """ # Assign labels based on current centroids labels = self._assign_labels(X, state.centroids) # Update centroids based on new assignments new_centroids = self._update_centroids(X, labels) # Compute objective objective = self._compute_objective(X, labels, new_centroids) # Check convergence based on centroid change centroid_change = jnp.linalg.norm(new_centroids - state.centroids, ord='fro') converged = centroid_change <= self.epsilon new_state = HCMState( centroids=new_centroids, labels=labels, objective=objective, iteration=state.iteration + 1, converged=converged ) metrics = { 'objective': objective, 'centroid_change': centroid_change, 'converged': converged } return new_state, metrics def _build_info(self, state, iteration): weights = jnp.ones(state.labels.shape[0]) return { 'centroids': state.centroids, 'labels': state.labels, 'weights': weights, 'iteration': iteration, 'objective': float(state.objective), 'max_iter': self.max_iter, }
[docs] def fit(self, X: chex.Array, initial_centroids=None, resume=False) -> Self: """Fit HCM model to data. Parameters ---------- X : array-like, shape (n_samples, n_features) Training data initial_centroids : array-like, shape (n_clusters, n_features), optional Pre-computed centroids for warm starting resume : bool, default=False If True, resume from the model's current fitted state """ X = self._validate_input(X) if resume and initial_centroids is not None: raise ValueError("Cannot use both resume=True and initial_centroids") # Initialize centroids if resume: self._check_fitted() centroids_init = self.centroids_ elif initial_centroids is not None: centroids_init = self._validate_initial_centroids(X, initial_centroids) else: centroids_init = self._initialize_centroids(X) initial_labels = self._assign_labels(X, centroids_init) initial_objective = self._compute_objective(X, initial_labels, centroids_init) initial_state = HCMState( centroids=centroids_init, labels=initial_labels, objective=initial_objective, iteration=0, converged=False ) final_state, self.history_ = self._run_training(X, initial_state) # Store results self.centroids_ = final_state.centroids self.labels_ = final_state.labels self.n_iter_ = int(final_state.iteration) self.objective_ = float(final_state.objective) self.objective_history_ = self.history_['objective'] return self
[docs] def predict(self, X: chex.Array) -> chex.Array: """Predict cluster labels for new data. Args: X: Input data, shape (n_samples, n_features) Returns: labels: Cluster labels, shape (n_samples,) Raises: ValueError: If model has not been fitted """ self._check_fitted() X = jnp.asarray(X) labels = self._assign_labels(X, self.centroids_) return labels
def predict_proba(self, X: chex.Array) -> chex.Array: """Predict hard cluster membership (one-hot encoding). For HCM, this returns a one-hot encoding of the cluster assignments, with 1.0 for the assigned cluster and 0.0 for others. Args: X: Input data, shape (n_samples, n_features) Returns: membership: One-hot encoded assignments, shape (n_samples, n_clusters) Raises: ValueError: If model has not been fitted """ self._check_fitted() labels = self.predict(X) n_samples = X.shape[0] # Create one-hot encoding membership = jnp.zeros((n_samples, self.n_clusters)) membership = membership.at[jnp.arange(n_samples), labels].set(1.0) return membership def get_distance_space(self, X: chex.Array) -> chex.Array: """Compute distances to all cluster centroids. Args: X: Input data, shape (n_samples, n_features) Returns: distances: Euclidean distances to centroids, shape (n_samples, n_clusters) Raises: ValueError: If model has not been fitted """ self._check_fitted() X = jnp.asarray(X) # Compute squared distances D_sq = self.distance_fn(X, self.centroids_) # Return Euclidean distances distances = jnp.sqrt(D_sq) return distances