# Source code for prosemble.models.kmeans

"""
JAX implementation of K-means++ (Hard C-Means with K-means++ initialization)

This is a GPU-accelerated implementation using JAX.
K-means++ provides better initialization than random selection.
"""

# Author: Nana Abeka Otoo <abekaotoo@gmail.com>
# License: MIT

from typing import Self

import jax
import jax.numpy as jnp
import chex
from jax import jit
from functools import partial

from prosemble.core.distance import batch_squared_euclidean
from .hcm import HCM


class KMeansPlusPlus:
    """
    K-means++ clustering with JAX (Hard C-Means with smart initialization).

    K-means++ is an algorithm for choosing initial cluster centers with
    better convergence properties than random initialization.

    Algorithm:
        1. Choose first center uniformly at random from data points
        2. For each data point x, compute D(x), the distance to nearest center
        3. Choose next center with probability proportional to D(x)²
        4. Repeat until k centers are chosen
        5. Run standard k-means (HCM) with these initial centers

    Parameters
    ----------
    n_clusters : int, default=3
        Number of clusters. Must be at least 1.
    max_iter : int, default=100
        Maximum number of iterations
    epsilon : float, default=1e-5
        Convergence tolerance
    random_seed : int, optional
        Random seed for reproducibility. When ``None``, the k-means++
        initialization falls back to the fixed seed 42; the value is passed
        to the internal HCM model unchanged.
    plot_steps : bool, default=False
        Whether to enable visualization (requires LiveVisualizer)

    Attributes
    ----------
    centroids_ : array of shape (n_clusters, n_features)
        Cluster centers
    labels_ : array of shape (n_samples,)
        Labels of each point
    n_iter_ : int
        Number of iterations run
    objective_ : float
        Final objective function value

    See Also
    --------
    HCM : Hard C-Means model used internally after initialization.
    """

    def __init__(
        self,
        n_clusters: int = 3,
        max_iter: int = 100,
        epsilon: float = 1e-5,
        random_seed: int | None = None,
        plot_steps: bool = False
    ):
        self.n_clusters = n_clusters
        self.max_iter = max_iter
        self.epsilon = epsilon
        self.random_seed = random_seed
        self.plot_steps = plot_steps

        # Fitted attributes (populated by fit()).
        self.centroids_ = None
        self.labels_ = None
        self.n_iter_ = 0
        self.objective_ = None
        self._hcm = None

    def _require_fitted(self) -> HCM:
        """Return the internal HCM model, raising ValueError if fit() has not run."""
        if self._hcm is None:
            raise ValueError("Model not fitted yet. Call fit() first.")
        return self._hcm

    def _kmeans_plusplus_init(self, X: chex.Array, key: chex.PRNGKey) -> chex.Array:
        """
        Initialize centroids using k-means++ algorithm.

        Uses a Python loop (not lax.fori_loop) because each iteration
        changes the number of selected centroids.

        Parameters
        ----------
        X : array of shape (n_samples, n_features)
            Training data
        key : PRNGKey
            Random key for JAX

        Returns
        -------
        centroids : array of shape (n_clusters, n_features)
            Initial cluster centers

        Raises
        ------
        ValueError
            If ``n_clusters`` is smaller than 1 or ``X`` has no samples.
        """
        # Without this check, n_clusters <= 0 would silently yield one centroid
        # because the selection loop below simply would not run.
        if self.n_clusters < 1:
            raise ValueError(
                f"n_clusters must be at least 1, got {self.n_clusters}."
            )

        X = jnp.asarray(X)
        n_samples, _ = X.shape  # unpacking also enforces a 2-D input
        if n_samples == 0:
            raise ValueError("X must contain at least one sample.")

        # Step 1: Choose first center uniformly at random
        key, subkey = jax.random.split(key)
        first_idx = jax.random.choice(subkey, n_samples)
        # Integer indexing plus a new leading axis gives shape (1, n_features)
        # without relying on array-valued slice bounds.
        centroids = X[first_idx][None, :]

        # Fallback weights for the degenerate case where every point already
        # coincides with a chosen center (all distances are zero).
        uniform = jnp.full((n_samples,), 1.0 / n_samples)

        # Steps 2-4: Choose remaining centers
        for _ in range(self.n_clusters - 1):
            # Squared distance from each sample to its nearest chosen centroid.
            D_sq = batch_squared_euclidean(X, centroids)  # (n_samples, n_centers)
            # NOTE(review): clamp guards against tiny negative values should the
            # distance helper lose precision; it is a no-op for exact distances.
            min_distances = jnp.maximum(jnp.min(D_sq, axis=1), 0.0)  # (n_samples,)
            total = jnp.sum(min_distances)

            # Sample proportional to squared distance; when there is no
            # distance mass at all, sample uniformly instead of handing an
            # all-zero weight vector to the sampler.
            key, subkey = jax.random.split(key)
            probs = jnp.where(
                total > 0,
                min_distances / jnp.maximum(total, 1e-10),
                uniform,
            )
            next_idx = jax.random.choice(subkey, n_samples, p=probs)

            centroids = jnp.concatenate(
                [centroids, X[next_idx][None, :]], axis=0
            )

        return centroids

    def fit(self, X: chex.Array) -> Self:
        """
        Fit K-means++ model to data.

        Parameters
        ----------
        X : array of shape (n_samples, n_features)
            Training data

        Returns
        -------
        self : object
            Fitted estimator
        """
        # Initialize centroids using k-means++ (seed 42 when none was given).
        seed = self.random_seed if self.random_seed is not None else 42
        key = jax.random.PRNGKey(seed)
        initial_centroids = self._kmeans_plusplus_init(X, key)

        # Use HCM with k-means++ initialization
        self._hcm = HCM(
            n_clusters=self.n_clusters,
            max_iter=self.max_iter,
            epsilon=self.epsilon,
            random_seed=self.random_seed,
            plot_steps=self.plot_steps
        )

        # Pass k-means++ centroids to HCM
        self._hcm.fit(X, initial_centroids=initial_centroids)

        # Copy results so they are available directly on this estimator.
        self.centroids_ = self._hcm.centroids_
        self.labels_ = self._hcm.labels_
        self.n_iter_ = self._hcm.n_iter_
        self.objective_ = self._hcm.objective_

        return self

    def predict(self, X: chex.Array) -> chex.Array:
        """
        Predict cluster labels for samples.

        Parameters
        ----------
        X : array of shape (n_samples, n_features)
            New data to predict

        Returns
        -------
        labels : array of shape (n_samples,)
            Index of the cluster each sample belongs to

        Raises
        ------
        ValueError
            If the model has not been fitted.
        """
        return self._require_fitted().predict(X)

    def get_objective_history(self):
        """Get the objective function history.

        Raises
        ------
        ValueError
            If the model has not been fitted.
        """
        return self._require_fitted().get_objective_history()

    def final_centroids(self):
        """Get final cluster centroids.

        Raises
        ------
        ValueError
            If no centroids are available because the model has not been fitted.
        """
        if self.centroids_ is None:
            raise ValueError("Model not fitted yet. Call fit() first.")
        return self.centroids_

    def get_distance_space(self, X: chex.Array) -> chex.Array:
        """
        Compute distance space (distances to all centroids).

        Parameters
        ----------
        X : array of shape (n_samples, n_features)
            Input data

        Returns
        -------
        distances : array of shape (n_samples, n_clusters)
            Distances to each centroid

        Raises
        ------
        ValueError
            If the model has not been fitted.
        """
        return self._require_fitted().get_distance_space(X)
# Alias for backward compatibility kmeans_plusplus_jax = KMeansPlusPlus Kmeans = KMeansPlusPlus