"""
JAX implementation of Fuzzy C-Means (FCM) clustering algorithm.
This module provides a fully vectorized implementation of FCM using JAX. The
core update steps are JIT-compiled, and computation runs on whichever device
JAX is configured to use (CPU, GPU, or TPU).
Mathematical Background
-----------------------
FCM is a soft clustering algorithm that assigns every data point a degree of
membership in each cluster. Unlike hard clustering (K-Means), a point can
belong to several clusters at once, with its membership degrees summing to one.
Objective Function:

.. math::

    J(U, V) = \\sum_i \\sum_j u_{ij}^m \\|x_i - v_j\\|^2

Subject to:

.. math::

    \\sum_j u_{ij} = 1 \\quad \\forall i \\quad \\text{(membership constraint)}

.. math::

    u_{ij} \\in [0, 1] \\quad \\text{(fuzzy membership)}
Update Rules:

.. math::

    v_j = \\frac{\\sum_i u_{ij}^m x_i}{\\sum_i u_{ij}^m}

.. math::

    u_{ij} = \\frac{1}{\\sum_k \\left(\\frac{d_{ij}}{d_{ik}}\\right)^{2/(m-1)}}

where:

- :math:`U`: fuzzy membership matrix (n x c)
- :math:`V`: cluster centroids (c x d)
- :math:`m`: fuzzifier parameter (must be > 1; typically 2.0)
- :math:`d_{ij} = \\|x_i - v_j\\|`: Euclidean distance from point :math:`x_i`
  to centroid :math:`v_j`. The implementation accepts a pluggable pairwise
  ``distance_fn`` whose output plays the role of :math:`d_{ij}^2`
  (squared Euclidean by default).
Author: Prosemble Contributors
License: MIT
"""
import jax
import jax.numpy as jnp
import numpy as np
from jax import jit, vmap
from functools import partial
from typing import NamedTuple, Self
import chex
from prosemble.models.base import FuzzyClusteringBase, ScanFitMixin
class FCMState(NamedTuple):
    """
    Immutable state for the FCM algorithm.

    JAX transformations such as ``jit`` and ``lax.scan`` operate on immutable
    pytrees. A NamedTuple is a pytree whose fields cannot be reassigned in
    place, so every iteration produces a fresh state object.

    Attributes:
        centroids: (c, d) array of cluster centroids
        U: (n, c) array of fuzzy membership values
        objective: Scalar objective function value
        iteration: Iteration counter. A plain Python ``int`` in the initial
            state built by ``FCM.fit``; a JAX integer scalar once produced
            by the compiled iteration step.
        converged: Convergence flag. A plain Python ``bool`` in the initial
            state; a JAX boolean scalar (the result of the centroid-change
            comparison) once produced by the compiled iteration step.
    """

    centroids: chex.Array
    U: chex.Array
    objective: chex.Array
    iteration: chex.Array | int
    converged: chex.Array | bool
class FCM(ScanFitMixin, FuzzyClusteringBase):
    """
    JAX implementation of Fuzzy C-Means clustering.

    This implementation provides:

    - Vectorized array operations instead of per-sample Python loops
    - JIT-compiled update steps (``jax.jit``)
    - Execution on whichever device JAX is configured for (CPU/GPU/TPU)
    - Immutable per-iteration state (``FCMState``)
    - Numerically stable membership updates (log-space softmax)

    Key Differences from NumPy Version
    ----------------------------------
    1. **Vectorization**: centroid, membership and objective computations
       are matrix operations instead of nested Python loops.
    2. **Functional**: each iteration returns a new immutable ``FCMState``
       instead of updating attributes in place.
    3. **JIT Compilation**: the update steps are compiled to XLA.
    4. **Device Support**: arrays live on the JAX default device, so the
       same code runs on CPU, GPU or TPU.

    Parameters
    ----------
    n_clusters : int
        Number of clusters (must be >= 2).
    fuzzifier : float, default=2.0
        Fuzzification parameter (:math:`m`). Must be > 1.

        - :math:`m \\to 1^+`: memberships approach hard (crisp) assignments
        - :math:`m = 2`: standard fuzzy clustering
        - :math:`m \\to \\infty`: memberships approach :math:`1/c` (maximum fuzziness)
    max_iter : int, default=100
        Maximum number of iterations.
    epsilon : float, default=1e-5
        Convergence threshold on the Frobenius norm of the centroid change.
    init_method : {'random', 'kmeans++'}, default='random'
        Initialization method for the :math:`U` matrix:

        - 'random': each row drawn from a flat Dirichlet distribution
        - 'kmeans++': k-means++ seeded centroids, then :math:`U` computed
          from them with the membership update rule

        Any other value raises ``ValueError``.
    random_seed : int, default=42
        Random seed for reproducibility.
    distance_fn : callable, optional
        Pairwise distance function ``(X, Y) -> (len(X), len(Y))``.
        Default: squared Euclidean. Its output is used directly as the
        dissimilarity in the objective and membership updates.
    patience : int, optional
        Epochs with no improvement before early stopping. Default: None.
    restore_best : bool, default=False
        If True, restore centroids from the lowest-objective epoch.
    plot_steps : bool, default=False
        Whether to visualize clustering progress.
    show_confidence : bool, default=True
        Whether to show confidence in visualization.
    show_pca_variance : bool, default=True
        Whether to show PCA variance in visualization.
    save_plot_path : str, optional
        Path to save final plot.
    callbacks : list, optional
        List of Callback objects for monitoring/visualization.

    Attributes
    ----------
    centroids_ : array of shape (n_clusters, n_features)
        Cluster centroids after fitting
    U_ : array of shape (n_samples, n_clusters)
        Fuzzy membership matrix after fitting
    objective_ : float
        Final objective function value
    n_iter_ : int
        Number of iterations performed
    history_ : dict
        Training history containing objective values and other metrics

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from prosemble.models import FCM
    >>>
    >>> # Generate sample data
    >>> X = jnp.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8], [1, 0.6], [9, 11]])
    >>>
    >>> # Fit FCM model
    >>> model = FCM(n_clusters=2, fuzzifier=2.0, max_iter=100)
    >>> model.fit(X)
    >>>
    >>> # Get results
    >>> labels = model.predict(X)
    >>> centroids = model.final_centroids()
    >>> membership = model.predict_proba(X)
    >>>
    >>> print(f"Labels: {labels}")
    >>> print(f"Centroids shape: {centroids.shape}")
    >>> print(f"Membership shape: {membership.shape}")

    References
    ----------
    Bezdek, J. C. (1981). Pattern Recognition with Fuzzy Objective
    Function Algorithms. Plenum Press, New York.

    Dunn, J. C. (1973). A Fuzzy Relative of the ISODATA Process and
    Its Use in Detecting Compact Well-Separated Clusters.
    Journal of Cybernetics, 3(3), 32-57.
    """

    _hyperparams = ('fuzzifier', 'init_method')
    _fitted_array_names = ('U_',)
    # Values accepted for ``init_method``; anything else is rejected in __init__.
    _init_methods = ('random', 'kmeans++')

    def __init__(
        self,
        n_clusters: int,
        fuzzifier: float = 2.0,
        max_iter: int = 100,
        epsilon: float = 1e-5,
        init_method: str = 'random',
        random_seed: int = 42,
        distance_fn=None,
        patience: int | None = None,
        restore_best: bool = False,
        plot_steps: bool = False,
        show_confidence: bool = True,
        show_pca_variance: bool = True,
        save_plot_path: str | None = None,
        callbacks=None,
    ):
        # Model-specific validation first, before the base class does any setup.
        if fuzzifier <= 1.0:
            raise ValueError(f"fuzzifier must be > 1, got {fuzzifier}")
        if init_method not in self._init_methods:
            # Previously unknown values were silently treated as 'random'.
            raise ValueError(
                f"init_method must be one of {self._init_methods}, "
                f"got {init_method!r}"
            )
        super().__init__(
            n_clusters=n_clusters, max_iter=max_iter, epsilon=epsilon,
            random_seed=random_seed, distance_fn=distance_fn, patience=patience,
            restore_best=restore_best, plot_steps=plot_steps,
            show_confidence=show_confidence, show_pca_variance=show_pca_variance,
            save_plot_path=save_plot_path, callbacks=callbacks,
        )
        self.fuzzifier = fuzzifier
        self.init_method = init_method
        # Model-specific fitted attributes
        self.U_ = None
        self.history_ = None

    @partial(jit, static_argnums=(0,))
    def _kmeans_plus_plus_centroids(
        self, X: chex.Array, key: chex.PRNGKey
    ) -> chex.Array:
        """
        Pick initial centroids from the data using k-means++ seeding.

        The first centroid is a uniformly random sample. Each later centroid
        is sampled with probability proportional to the ``distance_fn`` value
        between a point and its closest already-chosen centroid. With the
        default squared-Euclidean distance this is the classic :math:`D^2`
        weighting.

        Args:
            X: (n, d) data matrix
            key: JAX random key for reproducibility

        Returns:
            centroids: (c, d) array whose rows are taken from ``X``
        """
        n_samples = X.shape[0]
        key, first_key = jax.random.split(key)
        first = jax.random.randint(first_key, (), 0, n_samples)
        # Chosen row indices; slots 1..c-1 are overwritten in the loop below.
        indices = jnp.full((self.n_clusters,), first)
        # Distance from every sample to its nearest chosen centroid so far.
        closest = self.distance_fn(X, X[first][None, :])[:, 0]

        def add_centroid(j, carry):
            indices, closest, key = carry
            key, pick_key = jax.random.split(key)
            total = jnp.sum(closest)
            # If every point already coincides with a centroid (total == 0),
            # fall back to uniform sampling instead of dividing by zero.
            safe_total = jnp.where(total > 0, total, 1.0)
            probs = jnp.where(total > 0, closest / safe_total, 1.0 / n_samples)
            idx = jax.random.choice(pick_key, n_samples, p=probs)
            indices = indices.at[j].set(idx)
            closest = jnp.minimum(
                closest, self.distance_fn(X, X[idx][None, :])[:, 0]
            )
            return indices, closest, key

        indices, _, _ = jax.lax.fori_loop(
            1, self.n_clusters, add_centroid, (indices, closest, key)
        )
        return X[indices]

    @partial(jit, static_argnums=(0,))
    def _initialize_U(self, X: chex.Array, key: chex.PRNGKey) -> chex.Array:
        """
        Initialize the fuzzy membership matrix U according to ``init_method``.

        - ``'random'``: each row is drawn from a flat Dirichlet distribution,
          :math:`U_i \\sim \\text{Dir}(1, \\ldots, 1)`, so
          :math:`\\sum_j u_{ij} = 1` holds by construction.
        - ``'kmeans++'``: centroids are seeded with k-means++ and U is
          derived from them with the regular membership update.

        Args:
            X: (n, d) data matrix
            key: JAX random key for reproducibility

        Returns:
            U: (n, c) fuzzy membership matrix whose rows sum to 1
        """
        # ``self`` is static under jit, so this branch is resolved at trace time.
        if self.init_method == 'kmeans++':
            seeds = self._kmeans_plus_plus_centroids(X, key)
            return self._update_U(X, seeds)
        return jax.random.dirichlet(
            key,
            alpha=jnp.ones(self.n_clusters),
            shape=(X.shape[0],),
        )

    @partial(jit, static_argnums=(0,))
    def _compute_centroids(self, X: chex.Array, U: chex.Array) -> chex.Array:
        """
        Compute cluster centroids from the fuzzy membership matrix.

        Mathematical Formula:

        .. math::

            v_j = \\frac{\\sum_i u_{ij}^m x_i}{\\sum_i u_{ij}^m}

        Vectorized as :math:`V = (U^m)^T X / \\text{colsum}(U^m)^T`: a single
        matrix product replaces the nested per-cluster/per-sample loops of the
        original NumPy version (same O(ncd) work, no Python loop overhead).

        Args:
            X: (n, d) data matrix
            U: (n, c) fuzzy membership matrix

        Returns:
            V: (c, d) centroid matrix
        """
        U_fuzz = jnp.power(U, self.fuzzifier)  # (n, c)
        numerator = U_fuzz.T @ X  # (c, n) @ (n, d) -> (c, d)
        denominator = jnp.sum(U_fuzz, axis=0, keepdims=True).T  # (c, 1)
        # Guard against an empty cluster (all memberships ~0).
        denominator = jnp.maximum(denominator, 1e-10)
        return numerator / denominator

    @partial(jit, static_argnums=(0,))
    def _update_U(
        self,
        X: chex.Array,
        centroids: chex.Array
    ) -> chex.Array:
        """
        Update the fuzzy membership matrix.

        With :math:`D = \\text{distance\\_fn}(X, V)` (squared Euclidean by
        default, i.e. :math:`D_{ij} = d_{ij}^2`):

        .. math::

            u_{ij} = \\frac{1}{\\sum_k (D_{ij} / D_{ik})^{1/(m-1)}}
                   = \\frac{D_{ij}^{-1/(m-1)}}{\\sum_k D_{ik}^{-1/(m-1)}}

        which equals the textbook rule with exponent :math:`2/(m-1)` on
        Euclidean distances. The right-hand form is a softmax over clusters
        of :math:`-\\log D_{ij} / (m-1)`, evaluated with ``jax.nn.softmax``.
        This needs only an (n, c) intermediate instead of an (n, c, c)
        ratio tensor, and the max-subtraction inside softmax prevents
        overflow for fuzzifiers close to 1.

        Args:
            X: (n, d) data matrix
            centroids: (c, d) centroid matrix

        Returns:
            U: (n, c) updated fuzzy membership matrix; rows sum to 1
        """
        D = self.distance_fn(X, centroids)  # (n, c)
        # A point sitting exactly on a centroid would give log(0); clamp it.
        D = jnp.maximum(D, 1e-10)
        power = 1.0 / (self.fuzzifier - 1.0)
        return jax.nn.softmax(-power * jnp.log(D), axis=1)

    @partial(jit, static_argnums=(0,))
    def _compute_objective(
        self,
        X: chex.Array,
        centroids: chex.Array,
        U: chex.Array
    ) -> chex.Array:
        """
        Compute the FCM objective function.

        .. math::

            J = \\sum_i \\sum_j u_{ij}^m D_{ij}

        where :math:`D = \\text{distance\\_fn}(X, V)`. With the default
        squared-Euclidean distance this is
        :math:`\\sum_i \\sum_j u_{ij}^m \\|x_i - v_j\\|^2`, computed as
        :math:`\\sum(U^m \\odot D)` (element-wise product, then total sum).

        Args:
            X: (n, d) data matrix
            centroids: (c, d) centroids
            U: (n, c) fuzzy membership

        Returns:
            J: scalar objective value
        """
        D = self.distance_fn(X, centroids)  # (n, c)
        U_fuzz = jnp.power(U, self.fuzzifier)  # (n, c)
        return jnp.sum(U_fuzz * D)

    @partial(jit, static_argnums=(0,))
    def _check_convergence(
        self,
        centroids_old: chex.Array,
        centroids_new: chex.Array
    ) -> chex.Array:
        """
        Check whether the centroids have converged.

        Criterion: :math:`\\|V_{new} - V_{old}\\|_F < \\epsilon`.

        Args:
            centroids_old: Previous centroids
            centroids_new: Current centroids

        Returns:
            Boolean scalar (as a JAX array)
        """
        diff = jnp.linalg.norm(centroids_new - centroids_old, ord='fro')
        return diff < self.epsilon

    @partial(jit, static_argnums=(0,))
    def _iteration_step(
        self,
        state: FCMState,
        X: chex.Array
    ) -> tuple[FCMState, dict]:
        """
        Run one FCM iteration.

        Steps:
            1. Update U from the current centroids
            2. Recompute centroids from the new U
            3. Evaluate the objective
            4. Check convergence against the previous centroids

        Args:
            state: Current algorithm state
            X: Data matrix (passed as auxiliary data)

        Returns:
            new_state: Updated state
            metrics: Per-iteration metrics ('objective', 'centroid_change',
                'converged')
        """
        U_new = self._update_U(X, state.centroids)
        centroids_new = self._compute_centroids(X, U_new)
        obj_new = self._compute_objective(X, centroids_new, U_new)
        converged = self._check_convergence(state.centroids, centroids_new)

        new_state = FCMState(
            centroids=centroids_new,
            U=U_new,
            objective=obj_new,
            iteration=state.iteration + 1,
            converged=converged,
        )
        metrics = {
            'objective': obj_new,
            'centroid_change': jnp.linalg.norm(centroids_new - state.centroids),
            'converged': converged,
        }
        return new_state, metrics

    def _build_info(self, state, iteration):
        """
        Assemble a summary dict for one iteration.

        Labels are the argmax of each row of ``state.U``, and weights are the
        corresponding maximum membership values. NOTE(review): presumably
        consumed by the base class's callbacks/plotting; confirm the caller.
        """
        labels = jnp.argmax(state.U, axis=1)
        weights = jnp.max(state.U, axis=1)
        return {
            'centroids': state.centroids, 'labels': labels,
            'weights': weights, 'iteration': iteration,
            'objective': float(state.objective), 'max_iter': self.max_iter,
        }

    def fit(self, X: jnp.ndarray, initial_centroids=None, resume=False) -> Self:
        """Fit FCM model to data.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Training data
        initial_centroids : array-like, shape (n_clusters, n_features), optional
            Pre-computed centroids for warm starting
        resume : bool, default=False
            If True, resume from the model's current fitted state

        Returns
        -------
        self : FCM
            The fitted model.

        Raises
        ------
        ValueError
            If both ``resume=True`` and ``initial_centroids`` are given, or
            if ``X`` contains NaN/Inf values.
        RuntimeError
            If ``resume=True`` but the model has not been fitted yet (raised
            by ``_check_fitted``).
        """
        X = self._validate_input(X)
        if resume and initial_centroids is not None:
            raise ValueError("Cannot use both resume=True and initial_centroids")
        if not jnp.all(jnp.isfinite(X)):
            raise ValueError("X contains NaN or Inf values")

        # Choose the starting point: previous fit, user centroids, or init_method.
        if resume:
            self._check_fitted()
            centroids_init = self.centroids_
            U_init = self._update_U(X, centroids_init)
        elif initial_centroids is not None:
            centroids_init = self._validate_initial_centroids(X, initial_centroids)
            U_init = self._update_U(X, centroids_init)
        else:
            self.key, subkey = jax.random.split(self.key)
            U_init = self._initialize_U(X, subkey)
            centroids_init = self._compute_centroids(X, U_init)

        obj_init = self._compute_objective(X, centroids_init, U_init)
        initial_state = FCMState(
            centroids=centroids_init, U=U_init,
            objective=obj_init, iteration=0, converged=False
        )
        final_state, self.history_ = self._run_training(X, initial_state)

        self.centroids_ = final_state.centroids
        self.U_ = final_state.U
        self.objective_ = final_state.objective
        self.n_iter_ = final_state.iteration
        return self

    def predict(self, X: jnp.ndarray) -> jnp.ndarray:
        """
        Predict cluster labels for X.

        Each sample is assigned to the cluster with the highest membership.

        Args:
            X: (n_samples, n_features) data

        Returns:
            labels: (n_samples,) cluster assignments (0 to n_clusters-1)

        Raises:
            RuntimeError: If model not fitted yet
        """
        return jnp.argmax(self.predict_proba(X), axis=1)

    def predict_proba(self, X: jnp.ndarray) -> jnp.ndarray:
        """
        Predict fuzzy membership for X.

        Args:
            X: (n_samples, n_features) data

        Returns:
            U: (n_samples, n_clusters) fuzzy membership matrix.
                Each row sums to 1, values in [0, 1].

        Raises:
            RuntimeError: If model not fitted yet
        """
        self._check_fitted()
        X = jnp.asarray(X)
        return self._update_U(X, self.centroids_)

    def get_objective_history(self) -> jnp.ndarray:
        """
        Return the objective function values recorded during training.

        Returns:
            objectives: the ``'objective'`` entry of ``history_`` as produced
                by the training loop (presumably one value per scanned
                iteration; confirm against ``_run_training``)

        Raises:
            RuntimeError: If model not fitted yet
        """
        if self.history_ is None:
            raise RuntimeError("Model not fitted yet. Call fit() first.")
        return self.history_['objective']

    def get_distance_space(self, X: jnp.ndarray) -> jnp.ndarray:
        """
        Compute ``distance_fn`` values from samples to the fitted centroids.

        Args:
            X: (n_samples, n_features) data

        Returns:
            D: (n_samples, n_clusters) distance matrix (squared Euclidean
                with the default ``distance_fn``)

        Raises:
            RuntimeError: If model not fitted yet
        """
        self._check_fitted()
        X = jnp.asarray(X)
        return self.distance_fn(X, self.centroids_)