"""
JAX-based Kernel Fuzzy C-Means (KFCM) clustering implementation.
This module provides a GPU-accelerated implementation of KFCM using JAX
with JIT compilation for high performance.
"""
from typing import NamedTuple, Self
from functools import partial
import jax
import jax.numpy as jnp
import chex
from jax import jit, lax
from prosemble.core.kernel import batch_gaussian_kernel
from prosemble.models.base import FuzzyClusteringBase, ScanFitMixin
class KFCMState(NamedTuple):
"""Immutable state for KFCM iteration.
Attributes:
centroids: Cluster centroids, shape (n_clusters, n_features)
U: Fuzzy membership matrix, shape (n_samples, n_clusters)
objective: Current objective function value
iteration: Current iteration number
converged: Whether algorithm has converged
"""
centroids: chex.Array
U: chex.Array
objective: chex.Array
iteration: int
converged: bool
[docs]
class KFCM(ScanFitMixin, FuzzyClusteringBase):
"""
Kernel Fuzzy C-Means clustering with JAX.
KFCM uses a Gaussian kernel to map data into a high-dimensional feature space
where clustering is performed. This allows handling non-linearly separable data.
Kernel:
.. math::
K(x, y) = \\exp\\left(-\\frac{\\|x - y\\|^2}{\\sigma^2}\\right)
Kernel distance in feature space:
.. math::
\\|\\varphi(x) - \\varphi(y)\\|^2 = 2(1 - K(x, y))
Algorithm:
1. Initialize :math:`U` randomly
2. Update centroids (kernel-weighted)
3. Update :math:`U` using kernel distance
4. Repeat until convergence
Objective function:
.. math::
J = 2 \\sum_i \\sum_j u_{ij}^m (1 - K(x_i, v_j))
Parameters
----------
fuzzifier : float, default=2.0
Fuzziness parameter (must be > 1.0).
sigma : float, default=1.0
Kernel bandwidth parameter (must be > 0).
init_method : {'random'}, default='random'
Initialization method.
n_clusters : int
Number of clusters (must be >= 2).
max_iter : int
Maximum number of iterations.
epsilon : float
Convergence threshold.
random_seed : int
Random seed for reproducibility.
distance_fn : callable, optional
Pairwise distance function. Default: squared Euclidean.
patience : int, optional
Epochs with no improvement before early stopping. Default: None.
restore_best : bool
If True, restore centroids from the lowest-objective epoch.
Default: False.
plot_steps : bool
Whether to visualize clustering progress. Default: False.
show_confidence : bool
Whether to show confidence in visualization. Default: True.
show_pca_variance : bool
Whether to show PCA variance in visualization. Default: True.
save_plot_path : str, optional
Path to save final plot.
callbacks : list, optional
List of Callback objects for monitoring/visualization.
Attributes
----------
centroids_ : array, shape (n_clusters, n_features)
Final cluster centroids
U_ : array, shape (n_samples, n_clusters)
Final fuzzy membership matrix
n_iter_ : int
Number of iterations until convergence
objective_ : float
Final objective function value
objective_history_ : array
Objective values at each iteration
Examples
--------
>>> import jax.numpy as jnp
>>> from prosemble.models import KFCM
>>> X = jnp.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8]])
>>> model = KFCM(n_clusters=2, sigma=1.0, random_seed=42)
>>> model.fit(X)
>>> labels = model.predict(X)
"""
_hyperparams = ('fuzzifier', 'sigma', 'init_method')
_fitted_array_names = ('U_',)
def __init__(
self,
n_clusters: int,
fuzzifier: float = 2.0,
sigma: float = 1.0,
max_iter: int = 100,
epsilon: float = 1e-5,
init_method: str = 'random',
random_seed: int = 42,
distance_fn=None,
patience: int | None = None,
restore_best: bool = False,
plot_steps: bool = False,
show_confidence: bool = True,
show_pca_variance: bool = True,
save_plot_path: str | None = None,
callbacks=None,
):
# Validate model-specific parameters
if fuzzifier <= 1.0:
raise ValueError("fuzzifier must be > 1.0")
if sigma <= 0:
raise ValueError("sigma must be > 0")
if init_method != 'random':
raise ValueError("init_method must be 'random' for KFCM")
super().__init__(
n_clusters=n_clusters,
max_iter=max_iter,
epsilon=epsilon,
random_seed=random_seed,
distance_fn=distance_fn,
patience=patience,
restore_best=restore_best,
plot_steps=plot_steps,
show_confidence=show_confidence,
show_pca_variance=show_pca_variance,
save_plot_path=save_plot_path,
callbacks=callbacks,
)
self.fuzzifier = fuzzifier
self.sigma = sigma
self.init_method = init_method
# Model-specific fitted attributes
self.U_ = None
def _initialize(self, X: chex.Array):
"""Initialize :math:`U` matrix and centroids."""
n_samples = X.shape[0]
# Random U matrix (Dirichlet distribution ensures row sums = 1)
alpha = jnp.ones(self.n_clusters)
U = jax.random.dirichlet(self.key, alpha, shape=(n_samples,))
# Random centroids from data
indices = jax.random.choice(
self.key, n_samples, shape=(self.n_clusters,), replace=False
)
centroids = X[indices]
return U, centroids
@partial(jit, static_argnums=(0,))
def _compute_centroids(
self, X: chex.Array, U: chex.Array, centroids: chex.Array
) -> chex.Array:
"""Compute kernel-weighted centroids.
.. math::
v_j = \\frac{\\sum_i u_{ij}^m \\cdot K(x_i, v_j) \\cdot x_i}{\\sum_i u_{ij}^m \\cdot K(x_i, v_j)}
"""
# Compute kernel matrix K(X, centroids)
K = batch_gaussian_kernel(X, centroids, self.sigma) # (n_samples, n_clusters)
# Fuzzify U
U_fuzz = jnp.power(U, self.fuzzifier) # (n_samples, n_clusters)
# Kernel weights
weights = U_fuzz * K # (n_samples, n_clusters)
# Compute centroids
numerator = weights.T @ X # (n_clusters, n_features)
denominator = jnp.sum(weights, axis=0, keepdims=True).T # (n_clusters, 1)
denominator = jnp.maximum(denominator, 1e-10)
centroids_new = numerator / denominator
return centroids_new
@partial(jit, static_argnums=(0,))
def _update_U(self, X: chex.Array, centroids: chex.Array) -> chex.Array:
"""Update fuzzy membership matrix using kernel distance.
.. math::
u_{ij} = \\frac{1}{\\sum_k \\left(\\frac{1 - K(x_i, v_j)}{1 - K(x_i, v_k)}\\right)^{1/(m-1)}}
"""
# Compute kernel matrix
K = batch_gaussian_kernel(X, centroids, self.sigma) # (n_samples, n_clusters)
# Kernel distance: 1 - K(x, v)
kernel_dist = 1.0 - K # (n_samples, n_clusters)
kernel_dist = jnp.maximum(kernel_dist, 1e-10)
# Compute power
power = 1.0 / (self.fuzzifier - 1.0)
# Compute distance ratios
def compute_membership_row(distances_i):
# distances_i: (n_clusters,)
ratios = distances_i[:, None] / distances_i[None, :] # (n_clusters, n_clusters)
powered_ratios = jnp.power(ratios, power)
denominators = jnp.sum(powered_ratios, axis=1) # (n_clusters,)
memberships = 1.0 / denominators
return memberships
U = jax.vmap(compute_membership_row)(kernel_dist) # (n_samples, n_clusters)
# Normalize
U = U / jnp.sum(U, axis=1, keepdims=True)
return U
@partial(jit, static_argnums=(0,))
def _compute_objective(
self, X: chex.Array, U: chex.Array, centroids: chex.Array
) -> chex.Array:
"""Compute KFCM objective function.
.. math::
J = 2 \\sum_i \\sum_j u_{ij}^m (1 - K(x_i, v_j))
"""
# Compute kernel matrix
K = batch_gaussian_kernel(X, centroids, self.sigma)
# Kernel distance
kernel_dist = 1.0 - K
# Fuzzify U
U_fuzz = jnp.power(U, self.fuzzifier)
# Weighted sum
objective = 2.0 * jnp.sum(U_fuzz * kernel_dist)
return objective
@partial(jit, static_argnums=(0,))
def _iteration_step(
self, state: KFCMState, X: chex.Array
) -> tuple[KFCMState, dict]:
"""Single KFCM iteration step."""
# Update centroids
centroids_new = self._compute_centroids(X, state.U, state.centroids)
# Update U
U_new = self._update_U(X, centroids_new)
# Compute objective
objective = self._compute_objective(X, U_new, centroids_new)
# Check convergence
centroid_change = jnp.linalg.norm(centroids_new - state.centroids, ord='fro')
converged = centroid_change <= self.epsilon
new_state = KFCMState(
centroids=centroids_new,
U=U_new,
objective=objective,
iteration=state.iteration + 1,
converged=converged
)
metrics = {
'objective': objective,
'centroid_change': centroid_change,
'converged': converged
}
return new_state, metrics
def _build_info(self, state, iteration):
labels = jnp.argmax(state.U, axis=1)
weights = jnp.max(state.U, axis=1)
return {
'centroids': state.centroids, 'labels': labels,
'weights': weights, 'iteration': iteration,
'objective': float(state.objective), 'max_iter': self.max_iter,
}
[docs]
def fit(self, X: chex.Array, initial_centroids=None, resume=False) -> Self:
"""Fit KFCM model to data."""
if resume and initial_centroids is not None:
raise ValueError("Cannot use both resume=True and initial_centroids")
X = self._validate_input(X)
if resume:
self._check_fitted()
centroids_init = self.centroids_
U_init = self._update_U(X, centroids_init)
elif initial_centroids is not None:
centroids_init = self._validate_initial_centroids(X, initial_centroids)
U_init = self._update_U(X, centroids_init)
else:
U_init, centroids_init = self._initialize(X)
initial_objective = self._compute_objective(X, U_init, centroids_init)
initial_state = KFCMState(
centroids=centroids_init, U=U_init,
objective=initial_objective, iteration=0, converged=False
)
final_state, self.history_ = self._run_training(X, initial_state)
# Store results
self.centroids_ = final_state.centroids
self.U_ = final_state.U
self.n_iter_ = int(final_state.iteration)
self.objective_ = float(final_state.objective)
self.objective_history_ = self.history_['objective']
return self
[docs]
def predict(self, X: chex.Array) -> chex.Array:
"""Predict cluster labels for new data."""
self._check_fitted()
U = self._update_U(X, self.centroids_)
labels = jnp.argmax(U, axis=1)
return labels
[docs]
def predict_proba(self, X: chex.Array) -> chex.Array:
"""Predict fuzzy membership probabilities."""
self._check_fitted()
X = jnp.asarray(X)
U = self._update_U(X, self.centroids_)
return U