"""
JAX-based Possibilistic C-Means (PCM) clustering.
This module provides a GPU-accelerated implementation of Possibilistic C-Means
using JAX for automatic differentiation and JIT compilation.
PCM extends FCM by introducing typicality values that represent the degree to which
a data point belongs to a cluster, independent of other clusters. This makes PCM
less sensitive to outliers and noise compared to FCM.
Mathematical formulation:
Objective function:
.. math::
J = \\sum_i \\sum_j t_{ij}^m \\|x_i - v_j\\|^2 + \\sum_j \\gamma_j \\sum_i (1 - t_{ij})^m
where :math:`t_{ij}` is the typicality of point :math:`x_i` to cluster :math:`j`,
:math:`v_j` is the centroid of cluster :math:`j`, :math:`m` is the fuzzifier (:math:`m > 1`),
and :math:`\\gamma_j` is a scale parameter for cluster :math:`j`.
Update equations:
.. math::
v_j = \\frac{\\sum_i t_{ij}^m x_i}{\\sum_i t_{ij}^m}
.. math::
\\gamma_j = k \\cdot \\frac{\\sum_i t_{ij}^m \\|x_i - v_j\\|^2}{\\sum_i t_{ij}^m}
.. math::
t_{ij} = \\frac{1}{1 + \\left(\\frac{\\|x_i - v_j\\|^2}{\\gamma_j}\\right)^{1/(m-1)}}
References:
Krishnapuram, R., & Keller, J. M. (1993).
A possibilistic approach to clustering.
IEEE Transactions on Fuzzy Systems, 1(2), 98-110.
"""
# Author: Nana Abeka Otoo <abekaotoo@gmail.com>
# License: MIT
from typing import NamedTuple, Self
from functools import partial
import jax
import jax.numpy as jnp
import numpy as np
from jax import jit
import chex
from .fcm import FCM
from prosemble.models.base import FuzzyClusteringBase, ScanFitMixin
class PCMState(NamedTuple):
    """Immutable state for PCM training using JAX.

    A fresh instance is built for every iteration by ``_iteration_step``;
    ``fit`` creates the initial one and reads the results off the final one.

    Attributes:
        centroids: Cluster centroids, shape (c, d)
        T: Typicality matrix, shape (n, c)
        gamma: Scale parameters for each cluster, shape (c,)
        objective: Current objective function value
        iteration: Current iteration number
        converged: Whether the algorithm has converged
    """

    centroids: chex.Array
    T: chex.Array
    gamma: chex.Array
    objective: chex.Array
    # Seeded with the Python int 0 in ``fit`` and incremented by one per step.
    iteration: int
    # Seeded with False in ``fit``; every step replaces it with the result of
    # comparing the centroid change against ``epsilon``.
    converged: bool
[docs]
class PCM(ScanFitMixin, FuzzyClusteringBase):
"""
JAX-based Possibilistic C-Means clustering with GPU acceleration.
PCM is a clustering algorithm that assigns typicality values to data points,
representing the degree to which they belong to each cluster. Unlike FCM,
the typicality of a point to one cluster is independent of its typicality
to other clusters.
Parameters
----------
fuzzifier : float, default=2.0
Fuzzification parameter (:math:`m > 1`). Higher values result in fuzzier
clusters.
k : float, default=1.0
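Multiplier applied when computing :math:`\\gamma_j`. Typical values are in [0.01, 1.0].
Lower values shrink :math:`\\gamma_j`, so typicality decays faster with distance and
more points receive low typicality (i.e. are treated as outliers).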
init_method : {'fcm', 'random'}, default='fcm'
Initialization method:
- 'fcm': Initialize using FCM results (recommended)
- 'random': Random initialization
n_clusters : int
Number of clusters (must be >= 2).
max_iter : int
Maximum number of iterations.
epsilon : float
Convergence threshold.
random_seed : int
Random seed for reproducibility.
distance_fn : callable, optional
Pairwise distance function. Default: squared Euclidean.
patience : int, optional
Epochs with no improvement before early stopping. Default: None.
restore_best : bool
If True, restore centroids from the lowest-objective epoch.
Default: False.
plot_steps : bool
Whether to visualize clustering progress. Default: False.
show_confidence : bool
Whether to show confidence in visualization. Default: True.
show_pca_variance : bool
Whether to show PCA variance in visualization. Default: True.
save_plot_path : str, optional
Path to save final plot.
callbacks : list, optional
List of Callback objects for monitoring/visualization.
Attributes
----------
centroids_ : ndarray of shape (n_clusters, n_features)
Cluster centroids after fitting.
T_ : ndarray of shape (n_samples, n_clusters)
Typicality matrix after fitting.
gamma_ : ndarray of shape (n_clusters,)
Scale parameters for each cluster.
n_iter_ : int
Number of iterations run.
objective_ : float
Final objective function value.
Examples
--------
>>> import jax.numpy as jnp
>>> from prosemble.models import PCM
>>>
>>> # Generate sample data
>>> X = jnp.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8], [1, 0.6], [9, 11]])
>>>
>>> # Fit PCM model
>>> model = PCM(n_clusters=2, fuzzifier=2.0, k=1.0)
>>> model.fit(X)
>>>
>>> # Get cluster assignments
>>> labels = model.predict(X)
>>>
>>> # Get typicality values
>>> typicalities = model.predict_proba(X)
Notes
-----
- PCM is less sensitive to outliers than FCM because typicality values
are computed independently for each cluster.
- The parameter :math:`k` scales :math:`\\gamma_j` and thereby controls how quickly
typicality falls off with distance. Smaller values give tighter clusters and
flag more points as outliers.
- Initialization from FCM (init_method='fcm') is recommended as it provides
better starting points than random initialization.
- All computations are JIT-compiled and can run on GPU if available.
"""
_hyperparams = ('fuzzifier', 'k', 'init_method')
_fitted_array_names = ('T_', 'gamma_')
def __init__(
    self,
    n_clusters: int,
    fuzzifier: float = 2.0,
    k: float = 1.0,
    max_iter: int = 100,
    epsilon: float = 1e-5,
    init_method: str = 'fcm',
    random_seed: int | None = None,
    distance_fn=None,
    patience: int | None = None,
    restore_best: bool = False,
    plot_steps: bool = False,
    show_confidence: bool = True,
    show_pca_variance: bool = True,
    save_plot_path: str | None = None,
    callbacks=None,
):
    """Validate the PCM-specific hyperparameters and configure the model.

    ``fuzzifier``, ``k`` and ``init_method`` are checked and stored here;
    every other argument is forwarded unchanged to the base-class
    constructor. A ``random_seed`` of None is replaced by 42 before being
    forwarded.

    Raises:
        ValueError: If ``fuzzifier`` is not greater than 1.0, ``k`` is not
            greater than 0 (NaN is rejected in both cases), or
            ``init_method`` is neither 'fcm' nor 'random'.
    """
    # Phrased as ``not (x > bound)`` instead of ``x <= bound`` so that NaN,
    # which compares False against everything, is rejected too.
    if not fuzzifier > 1.0:
        raise ValueError(f"fuzzifier must be > 1.0, got {fuzzifier}")
    if not k > 0:
        raise ValueError(f"k must be > 0, got {k}")
    if init_method not in ('fcm', 'random'):
        raise ValueError(f"init_method must be 'fcm' or 'random', got {init_method}")

    # Fall back to a fixed seed so that runs are reproducible by default.
    if random_seed is None:
        random_seed = 42

    super().__init__(
        n_clusters=n_clusters,
        max_iter=max_iter,
        epsilon=epsilon,
        random_seed=random_seed,
        distance_fn=distance_fn,
        patience=patience,
        restore_best=restore_best,
        plot_steps=plot_steps,
        show_confidence=show_confidence,
        show_pca_variance=show_pca_variance,
        save_plot_path=save_plot_path,
        callbacks=callbacks,
    )

    self.fuzzifier = fuzzifier
    self.k = k
    self.init_method = init_method

    # Fitted attributes, populated by ``fit``.
    self.T_ = None
    self.gamma_ = None
    self._objective_history = None
def _initialize_from_fcm(self, X: chex.Array) -> tuple[chex.Array, chex.Array]:
    """
    Initialize centroids and typicality matrix using FCM.

    An FCM model with the same cluster count, fuzzifier, iteration budget,
    tolerance and distance function is fitted on ``X``. Its centroids become
    the starting centroids and its membership matrix is reused as the
    starting typicality matrix.

    Args:
        X: Data matrix of shape (n, d)

    Returns:
        centroids: Initial centroids of shape (c, d)
        T: Initial typicality matrix of shape (n, c)
    """
    # Derive the FCM seed from the *whole* PRNG key. Reading only the first
    # word of a raw key (``int(self.key[0])``) yields 0 for every key created
    # by ``jax.random.PRNGKey(seed)`` with seed < 2**32 under JAX's default
    # implementation, which would make FCM ignore the configured seed.
    # ``fold_in`` produces a derived key without consuming ``self.key``, so
    # repeated calls stay deterministic.
    seed_key = jax.random.fold_in(self.key, 0)
    fcm_seed = int(jax.random.randint(seed_key, (), 0, jnp.iinfo(jnp.int32).max))

    fcm = FCM(
        n_clusters=self.n_clusters,
        fuzzifier=self.fuzzifier,
        max_iter=self.max_iter,
        epsilon=self.epsilon,
        random_seed=fcm_seed,
        distance_fn=self.distance_fn,
    )
    fcm.fit(X)

    # FCM memberships serve as the initial typicalities.
    return fcm.centroids_, fcm.U_
def _initialize_random(self, X: chex.Array) -> tuple[chex.Array, chex.Array]:
    """
    Draw a random starting point for PCM.

    Centroids are ``n_clusters`` distinct rows of ``X``; each row of the
    typicality matrix is a sample from a flat Dirichlet distribution.
    ``self.key`` is advanced as a side effect.

    Args:
        X: Data matrix of shape (n, d)

    Returns:
        centroids: Initial centroids of shape (c, d)
        T: Initial typicality matrix of shape (n, c)
    """
    n_samples, _ = X.shape

    # One key for the centroid draw, one for the typicality draw, and a
    # third that replaces the stored key for later use.
    centroid_key, typicality_key, next_key = jax.random.split(self.key, 3)
    self.key = next_key

    # Sampling without replacement guarantees distinct data points.
    picked_rows = jax.random.choice(
        centroid_key, n_samples, shape=(self.n_clusters,), replace=False
    )
    initial_centroids = X[picked_rows]

    concentration = jnp.ones(self.n_clusters)
    initial_typicality = jax.random.dirichlet(
        typicality_key, concentration, shape=(n_samples,)
    )
    return initial_centroids, initial_typicality
@partial(jit, static_argnums=(0,))
def _compute_centroids(self, X: chex.Array, T: chex.Array) -> chex.Array:
    """
    Compute cluster centroids as typicality-weighted means of the data.

    .. math::
        v_j = \\frac{\\sum_i t_{ij}^m x_i}{\\sum_i t_{ij}^m}

    Args:
        X: Data matrix of shape (n, d)
        T: Typicality matrix of shape (n, c)

    Returns:
        centroids: Cluster centroids of shape (c, d)
    """
    weights = jnp.power(T, self.fuzzifier)  # (n, c)

    # Total weight per cluster, kept as a column so it broadcasts over the
    # feature axis of the weighted sums.
    weight_totals = jnp.sum(weights, axis=0)[:, jnp.newaxis]  # (c, 1)
    weighted_sums = jnp.matmul(weights.T, X)  # (c, n) @ (n, d) -> (c, d)

    # Floor the divisor so a cluster with (almost) no weight cannot cause a
    # division by zero.
    return weighted_sums / jnp.maximum(weight_totals, 1e-10)
@partial(jit, static_argnums=(0,))
def _compute_gamma(self, X: chex.Array, T: chex.Array, centroids: chex.Array) -> chex.Array:
    """
    Compute the per-cluster scale parameters :math:`\\gamma_j`.

    .. math::
        \\gamma_j = k \\cdot \\frac{\\sum_i t_{ij}^m \\|x_i - v_j\\|^2}{\\sum_i t_{ij}^m}

    Args:
        X: Data matrix of shape (n, d)
        T: Typicality matrix of shape (n, c)
        centroids: Cluster centroids of shape (c, d)

    Returns:
        gamma: Scale parameters of shape (c,)
    """
    weights = jnp.power(T, self.fuzzifier)  # (n, c)
    # distance_fn is expected to return pairwise (squared) distances, (n, c).
    sq_dist = self.distance_fn(X, centroids)

    weighted_dist = (weights * sq_dist).sum(axis=0)  # (c,)
    weight_totals = weights.sum(axis=0)  # (c,)

    # Floor the divisor to avoid dividing by zero for an empty cluster.
    return self.k * weighted_dist / jnp.maximum(weight_totals, 1e-10)
@partial(jit, static_argnums=(0,))
def _update_typicality(
    self,
    X: chex.Array,
    centroids: chex.Array,
    gamma: chex.Array
) -> chex.Array:
    """
    Update typicality matrix.

    .. math::
        t_{ij} = \\frac{1}{1 + \\left(\\frac{\\|x_i - v_j\\|^2}{\\gamma_j}\\right)^{1/(m-1)}}

    Args:
        X: Data matrix of shape (n, d)
        centroids: Cluster centroids of shape (c, d)
        gamma: Scale parameters of shape (c,)

    Returns:
        T: Updated typicality matrix of shape (n, c), values in [0, 1]
    """
    # Pairwise (squared) distances, (n, c). Clamp at zero: if the distance
    # function returns a slightly negative value through round-off, raising
    # it to the fractional power 1/(m-1) would give NaN, which the final clip
    # cannot repair. Non-negative distances are unaffected.
    sq_dist = jnp.maximum(self.distance_fn(X, centroids), 0.0)

    exponent = 1.0 / (self.fuzzifier - 1.0)

    # Floor gamma so a degenerate cluster (gamma == 0) cannot cause a
    # division by zero.
    safe_gamma = jnp.maximum(gamma[jnp.newaxis, :], 1e-10)  # (1, c)
    scaled = jnp.power(sq_dist / safe_gamma, exponent)  # (n, c)

    typicality = 1.0 / (1.0 + scaled)

    # Keep the result inside the valid range [0, 1].
    return jnp.clip(typicality, 0.0, 1.0)
@partial(jit, static_argnums=(0,))
def _compute_objective(
    self,
    X: chex.Array,
    centroids: chex.Array,
    T: chex.Array,
    gamma: chex.Array
) -> chex.Array:
    """
    Evaluate the PCM objective for the given centroids, typicalities and scales.

    .. math::
        J = \\sum_i \\sum_j t_{ij}^m \\|x_i - v_j\\|^2 + \\sum_j \\gamma_j \\sum_i (1 - t_{ij})^m

    Args:
        X: Data matrix of shape (n, d)
        centroids: Cluster centroids of shape (c, d)
        T: Typicality matrix of shape (n, c)
        gamma: Scale parameters of shape (c,)

    Returns:
        objective: Scalar objective value
    """
    m = self.fuzzifier

    # Compactness term: typicality-weighted distances to the centroids.
    sq_dist = self.distance_fn(X, centroids)  # (n, c)
    compactness = jnp.sum(jnp.power(T, m) * sq_dist)

    # Penalty term: gamma-weighted mass of (1 - t)^m per cluster, which
    # discourages the trivial solution of all-zero typicalities.
    penalty_per_cluster = jnp.sum(jnp.power(1.0 - T, m), axis=0)  # (c,)
    penalty = jnp.sum(gamma * penalty_per_cluster)

    return compactness + penalty
@partial(jit, static_argnums=(0,))
def _iteration_step(self, state: PCMState, X: chex.Array) -> tuple[PCMState, dict]:
    """
    Advance PCM by one iteration.

    The updates run in a fixed order, each consuming the previous result:
    centroids (from the old typicalities), then gamma, then typicalities,
    then the objective.

    Args:
        state: Current PCM state
        X: Data matrix of shape (n, d)

    Returns:
        new_state: Updated PCM state
        metrics: Dictionary with 'objective', 'centroid_change' and 'converged'
    """
    centroids = self._compute_centroids(X, state.T)
    gamma = self._compute_gamma(X, state.T, centroids)
    typicality = self._update_typicality(X, centroids, gamma)
    objective = self._compute_objective(X, centroids, typicality, gamma)

    # Convergence is judged by how far the centroids moved in this step.
    shift = jnp.linalg.norm(centroids - state.centroids)
    has_converged = shift < self.epsilon

    metrics = {
        'objective': objective,
        'centroid_change': shift,
        'converged': has_converged,
    }
    next_state = PCMState(
        centroids=centroids,
        T=typicality,
        gamma=gamma,
        objective=objective,
        iteration=state.iteration + 1,
        converged=has_converged,
    )
    return next_state, metrics
def _build_info(self, state, iteration):
import numpy as np
labels = jnp.argmax(state.T, axis=1)
weights = jnp.max(state.T, axis=1)
max_typicality = np.asarray(weights)
outlier_threshold = 0.5
n_outliers = int(np.sum(max_typicality < outlier_threshold))
return {
'centroids': state.centroids, 'labels': labels,
'weights': weights, 'iteration': iteration,
'objective': float(state.objective), 'max_iter': self.max_iter,
'outlier_count': n_outliers,
}
[docs]
def fit(self, X: chex.Array, initial_centroids=None, resume=False) -> Self:
    """Fit PCM clustering model to data.

    The starting point comes from one of three sources: the previously
    fitted state (``resume=True``), user-supplied centroids
    (``initial_centroids``), or the configured ``init_method``.

    Parameters
    ----------
    X : array-like, shape (n_samples, n_features)
        Training data
    initial_centroids : array-like, shape (n_clusters, n_features), optional
        Pre-computed centroids for warm starting
    resume : bool, default=False
        If True, resume from the model's current fitted state

    Returns
    -------
    self
        The fitted model.

    Raises
    ------
    ValueError
        If both ``resume=True`` and ``initial_centroids`` are given.
    """
    X = self._validate_input(X)

    warm_start = initial_centroids is not None
    if resume and warm_start:
        raise ValueError("Cannot use both resume=True and initial_centroids")

    if resume:
        # Continue from the stored result of an earlier fit.
        self._check_fitted()
        centroids, typicality, gamma = self.centroids_, self.T_, self.gamma_
    else:
        if warm_start:
            centroids = self._validate_initial_centroids(X, initial_centroids)
            # Bootstrap: uniform typicalities give a first gamma, which in
            # turn gives typicalities consistent with the supplied centroids.
            uniform = jnp.ones((X.shape[0], self.n_clusters)) / self.n_clusters
            bootstrap_gamma = self._compute_gamma(X, uniform, centroids)
            typicality = self._update_typicality(X, centroids, bootstrap_gamma)
        elif self.init_method == 'fcm':
            centroids, typicality = self._initialize_from_fcm(X)
        else:
            centroids, typicality = self._initialize_random(X)
        # Gamma always reflects the typicalities actually used to start.
        gamma = self._compute_gamma(X, typicality, centroids)

    start_objective = self._compute_objective(X, centroids, typicality, gamma)
    start_state = PCMState(
        centroids=centroids,
        T=typicality,
        gamma=gamma,
        objective=start_objective,
        iteration=0,
        converged=False,
    )

    final_state, history = self._run_training(X, start_state)
    self.history_ = history

    # Expose the final state through the fitted attributes.
    self.centroids_ = final_state.centroids
    self.T_ = final_state.T
    self.gamma_ = final_state.gamma
    self.n_iter_ = int(final_state.iteration)
    self.objective_ = final_state.objective
    self.objective_history_ = self.history_['objective']
    return self
[docs]
def predict(self, X: chex.Array) -> chex.Array:
    """
    Predict cluster labels for data.

    Each sample is assigned to the centroid with the smallest value of
    ``distance_fn``; the typicality matrix and ``gamma_`` are not consulted.

    Args:
        X: Data matrix of shape (n_samples, n_features)

    Returns:
        labels: Cluster labels of shape (n_samples,)

    Raises:
        ValueError: If model has not been fitted (as reported by
            ``_check_fitted``)
    """
    self._check_fitted()
    chex.assert_rank(X, 2)

    distances = self.distance_fn(X, self.centroids_)  # (n, c)
    return jnp.argmin(distances, axis=1)
[docs]
def predict_proba(self, X: chex.Array) -> chex.Array:
    """
    Predict typicality values for data.

    The typicalities are evaluated against the fitted centroids and the
    fitted ``gamma_``; nothing is re-estimated from ``X``.

    Args:
        X: Data matrix of shape (n_samples, n_features)

    Returns:
        T: Typicality matrix of shape (n_samples, n_clusters)

    Raises:
        ValueError: If model has not been fitted (as reported by
            ``_check_fitted``)
    """
    self._check_fitted()
    chex.assert_rank(X, 2)
    return self._update_typicality(X, self.centroids_, self.gamma_)