"""
JAX-based Kernel Possibilistic C-Means (KPCM) clustering implementation.
This module provides a GPU-accelerated implementation of KPCM using JAX
with JIT compilation for high performance.
"""
from typing import NamedTuple, Self
from functools import partial
import jax
import jax.numpy as jnp
import chex
from jax import jit
from prosemble.core.kernel import batch_gaussian_kernel
from prosemble.models.base import FuzzyClusteringBase, ScanFitMixin
from prosemble.models.kfcm import KFCM
class KPCMState(NamedTuple):
    """Immutable state carried from one KPCM iteration to the next.

    Attributes:
        centroids: Cluster centroids, shape (n_clusters, n_features).
        T: Typicality matrix, shape (n_samples, n_clusters).
        gamma: Scale parameters, shape (n_clusters,).
        objective: Current objective function value.
        iteration: Current iteration number.
        converged: Whether the algorithm has converged.
    """

    # Field order is part of the tuple's positional interface; do not reorder.
    centroids: chex.Array
    T: chex.Array
    gamma: chex.Array
    objective: chex.Array
    # Plain Python values in the initial state built by ``KPCM.fit``; after a
    # jitted iteration step they come back as JAX scalars.
    iteration: int
    converged: bool
class KPCM(ScanFitMixin, FuzzyClusteringBase):
    """
    Kernel Possibilistic C-Means clustering with JAX.

    KPCM extends PCM to kernel space using Gaussian kernel, allowing handling
    of non-linearly separable data while maintaining possibilistic properties.

    Kernel:

    .. math::
        K(x, y) = \\exp\\left(-\\frac{\\|x - y\\|^2}{\\sigma^2}\\right)

    Kernel distance:

    .. math::
        d_K(x, v) = 2(1 - K(x, v))

    Algorithm:

    1. Initialize using KFCM
    2. Compute :math:`\\gamma` parameters
    3. Update typicality matrix :math:`T`
    4. Update centroids (kernel-weighted)
    5. Repeat until convergence

    Objective function:

    .. math::
        J = \\sum_i \\sum_j t_{ij}^m \\cdot d_K(x_i, v_j) + \\sum_j \\gamma_j \\sum_i (1 - t_{ij})^m

    Parameters
    ----------
    n_clusters : int
        Number of clusters (must be >= 2).
    fuzzifier : float, default=2.0
        Fuzziness parameter (must be > 1.0).
    k : float, default=1.0
        Scaling parameter for :math:`\\gamma` (must be > 0).
    sigma : float, default=1.0
        Kernel bandwidth parameter (must be > 0).
    max_iter : int
        Maximum number of iterations.
    epsilon : float
        Convergence threshold.
    init_method : {'kfcm'}, default='kfcm'
        Initialization method.
    random_seed : int
        Random seed for reproducibility.
    distance_fn : callable, optional
        Pairwise distance function. Default: squared Euclidean.
    patience : int, optional
        Epochs with no improvement before early stopping. Default: None.
    restore_best : bool
        If True, restore centroids from the lowest-objective epoch.
        Default: False.
    plot_steps : bool
        Whether to visualize clustering progress. Default: False.
    show_confidence : bool
        Whether to show confidence in visualization. Default: True.
    show_pca_variance : bool
        Whether to show PCA variance in visualization. Default: True.
    save_plot_path : str, optional
        Path to save final plot.
    callbacks : list, optional
        List of Callback objects for monitoring/visualization.

    Attributes
    ----------
    centroids_ : array, shape (n_clusters, n_features)
        Final cluster centroids
    T_ : array, shape (n_samples, n_clusters)
        Final typicality matrix
    gamma_ : array, shape (n_clusters,)
        Final scale parameters
    n_iter_ : int
        Number of iterations until convergence
    objective_ : float
        Final objective function value
    objective_history_ : array
        Objective values at each iteration
    history_ : mapping
        Training history returned by the training loop; its ``'objective'``
        entry is what ``objective_history_`` exposes.

    Notes
    -----
    The numerical kernels are jitted with ``self`` as a static argument and
    read ``fuzzifier``, ``k``, ``sigma`` and ``epsilon`` from the instance
    while tracing. NOTE(review): mutating those attributes after the first
    call may not trigger a retrace (this depends on how the base class
    hashes instances) -- prefer constructing a new model instead.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from prosemble.models import KPCM
    >>> X = jnp.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8]])
    >>> model = KPCM(n_clusters=2, sigma=1.0, random_seed=42)
    >>> model.fit(X)
    >>> labels = model.predict(X)
    """

    # NOTE(review): 'k' is a constructor hyperparameter but is not listed
    # here. Confirm against the base class whether that is intentional before
    # adding it (doing so could change any format derived from this tuple).
    _hyperparams = ('fuzzifier', 'sigma', 'init_method')
    _fitted_array_names = ('T_', 'gamma_')

    def __init__(
        self,
        n_clusters: int,
        fuzzifier: float = 2.0,
        k: float = 1.0,
        sigma: float = 1.0,
        max_iter: int = 100,
        epsilon: float = 1e-5,
        init_method: str = 'kfcm',
        random_seed: int = 42,
        distance_fn=None,
        patience: int | None = None,
        restore_best: bool = False,
        plot_steps: bool = False,
        show_confidence: bool = True,
        show_pca_variance: bool = True,
        save_plot_path: str | None = None,
        callbacks=None,
    ):
        """Validate the KPCM-specific hyperparameters and set up the model.

        The remaining arguments are forwarded unchanged to the base-class
        constructor.

        Raises
        ------
        ValueError
            If ``fuzzifier <= 1.0``, ``k <= 0``, ``sigma <= 0`` or
            ``init_method`` is anything other than ``'kfcm'``.
        """
        # Validate model-specific parameters before touching the base class.
        if fuzzifier <= 1.0:
            raise ValueError("fuzzifier must be > 1.0")
        if k <= 0:
            raise ValueError("k must be > 0")
        if sigma <= 0:
            raise ValueError("sigma must be > 0")
        if init_method != 'kfcm':
            raise ValueError("init_method must be 'kfcm' for KPCM")

        super().__init__(
            n_clusters=n_clusters,
            max_iter=max_iter,
            epsilon=epsilon,
            random_seed=random_seed,
            distance_fn=distance_fn,
            patience=patience,
            restore_best=restore_best,
            plot_steps=plot_steps,
            show_confidence=show_confidence,
            show_pca_variance=show_pca_variance,
            save_plot_path=save_plot_path,
            callbacks=callbacks,
        )

        self.fuzzifier = fuzzifier
        self.k = k
        self.sigma = sigma
        self.init_method = init_method

        # Model-specific fitted attributes (populated by ``fit``).
        self.T_ = None
        self.gamma_ = None

    def _initialize(
        self, X: chex.Array
    ) -> tuple[chex.Array, chex.Array, chex.Array]:
        """Initialize using KFCM.

        A KFCM model with the same ``n_clusters``, ``fuzzifier``, ``sigma``,
        ``max_iter``, ``epsilon`` and ``random_seed`` is fitted on ``X``; its
        membership matrix and centroids seed the scale parameters.

        Returns
        -------
        tuple
            ``(U, centroids, gamma)`` where ``U`` and ``centroids`` come from
            the fitted KFCM model and ``gamma`` is computed from them.
        """
        kfcm = KFCM(
            n_clusters=self.n_clusters,
            fuzzifier=self.fuzzifier,
            sigma=self.sigma,
            max_iter=self.max_iter,
            epsilon=self.epsilon,
            random_seed=self.random_seed,
            plot_steps=False,
        )
        kfcm.fit(X)

        U = kfcm.U_
        centroids = kfcm.centroids_

        # Initialize gamma from U
        gamma = self._compute_gamma(X, U, centroids)
        return U, centroids, gamma

    @partial(jit, static_argnums=(0,))
    def _compute_gamma(
        self, X: chex.Array, U: chex.Array, centroids: chex.Array
    ) -> chex.Array:
        """Compute scale parameters.

        .. math::
            \\gamma_j = k \\cdot \\frac{\\sum_i u_{ij}^m \\cdot d_K(x_i, v_j)}{\\sum_i u_{ij}^m}

        The denominator is clamped to ``1e-10`` so that a cluster with no
        membership mass yields a finite value instead of ``0 / 0 = NaN``.
        """
        # Kernel matrix between samples and centroids, then kernel distance.
        K = batch_gaussian_kernel(X, centroids, self.sigma)
        kernel_dist = 2.0 * (1.0 - K)

        # Fuzzify U
        U_fuzz = jnp.power(U, self.fuzzifier)

        numerator = jnp.sum(U_fuzz * kernel_dist, axis=0)
        denominator = jnp.sum(U_fuzz, axis=0)
        # Same guard as in ``_compute_centroids``: avoid division by zero.
        denominator = jnp.maximum(denominator, 1e-10)

        gamma = self.k * numerator / denominator
        return gamma

    @partial(jit, static_argnums=(0,))
    def _update_T(
        self, X: chex.Array, centroids: chex.Array, gamma: chex.Array
    ) -> chex.Array:
        """Update typicality matrix.

        .. math::
            t_{ij} = \\frac{1}{1 + \\left(\\frac{d_K(x_i, v_j)}{\\gamma_j}\\right)^{1/(m-1)}}
        """
        # Compute kernel distance, floored at 1e-10.
        K = batch_gaussian_kernel(X, centroids, self.sigma)
        kernel_dist = 2.0 * (1.0 - K)
        kernel_dist = jnp.maximum(kernel_dist, 1e-10)

        # Exponent 1 / (m - 1); ``fuzzifier > 1`` is enforced in ``__init__``.
        power = 1.0 / (self.fuzzifier - 1.0)

        # Broadcast the per-cluster gamma across the sample axis.
        ratio = kernel_dist / gamma[None, :]
        T = 1.0 / (1.0 + jnp.power(ratio, power))
        return T

    @partial(jit, static_argnums=(0,))
    def _compute_centroids(
        self, X: chex.Array, T: chex.Array, centroids: chex.Array
    ) -> chex.Array:
        """Compute kernel-weighted centroids.

        .. math::
            v_j = \\frac{\\sum_i t_{ij}^m \\cdot K(x_i, v_j) \\cdot x_i}{\\sum_i t_{ij}^m \\cdot K(x_i, v_j)}

        The kernel is evaluated against the *current* ``centroids``; the
        returned array holds the updated positions.
        """
        # Compute kernel matrix
        K = batch_gaussian_kernel(X, centroids, self.sigma)

        # Fuzzify T and combine with the kernel values.
        T_fuzz = jnp.power(T, self.fuzzifier)
        weights = T_fuzz * K

        # Weighted mean of the samples per cluster.
        numerator = weights.T @ X
        denominator = jnp.sum(weights, axis=0, keepdims=True).T
        denominator = jnp.maximum(denominator, 1e-10)  # avoid division by zero

        centroids_new = numerator / denominator
        return centroids_new

    @partial(jit, static_argnums=(0,))
    def _compute_objective(
        self, X: chex.Array, T: chex.Array, centroids: chex.Array, gamma: chex.Array
    ) -> chex.Array:
        """Compute KPCM objective function.

        .. math::
            J = \\sum_i \\sum_j t_{ij}^m \\cdot d_K(x_i, v_j) + \\sum_j \\gamma_j \\sum_i (1 - t_{ij})^m
        """
        # Compute kernel distance
        K = batch_gaussian_kernel(X, centroids, self.sigma)
        kernel_dist = 2.0 * (1.0 - K)

        # First term: typicality-weighted kernel distances.
        T_fuzz = jnp.power(T, self.fuzzifier)
        term1 = jnp.sum(T_fuzz * kernel_dist)

        # Second term: gamma-weighted sum of (1 - t)^m per cluster.
        one_minus_T = 1.0 - T
        one_minus_T_fuzz = jnp.power(one_minus_T, self.fuzzifier)
        inner_sum = jnp.sum(one_minus_T_fuzz, axis=0)
        term2 = jnp.sum(gamma * inner_sum)

        objective = term1 + term2
        return objective

    @partial(jit, static_argnums=(0,))
    def _iteration_step(self, state: KPCMState, X: chex.Array) -> tuple[KPCMState, dict]:
        """Single KPCM iteration step.

        Updates the typicality matrix, then the centroids, then evaluates the
        objective. ``gamma`` is carried through unchanged. Convergence is
        declared when the Frobenius norm of the centroid change is at most
        ``epsilon``.

        Returns
        -------
        tuple
            ``(new_state, metrics)`` where ``metrics`` has the keys
            ``'objective'``, ``'centroid_change'`` and ``'converged'``.
        """
        # Update T
        T_new = self._update_T(X, state.centroids, state.gamma)

        # Update centroids
        centroids_new = self._compute_centroids(X, T_new, state.centroids)

        # Compute objective
        objective = self._compute_objective(X, T_new, centroids_new, state.gamma)

        # Check convergence
        centroid_change = jnp.linalg.norm(centroids_new - state.centroids, ord='fro')
        converged = centroid_change <= self.epsilon

        new_state = KPCMState(
            centroids=centroids_new,
            T=T_new,
            gamma=state.gamma,
            objective=objective,
            iteration=state.iteration + 1,
            converged=converged,
        )

        metrics = {
            'objective': new_state.objective,
            'centroid_change': centroid_change,
            'converged': new_state.converged,
        }
        return new_state, metrics

    def _build_info(self, state, iteration):
        """Summarise ``state`` as a plain dict.

        Hard labels are the argmax of the typicality matrix and the weights
        are the corresponding maximum typicalities. Presumably consumed by
        the base-class training loop / callbacks -- verify against the base.
        """
        labels = jnp.argmax(state.T, axis=1)
        weights = jnp.max(state.T, axis=1)
        return {
            'centroids': state.centroids, 'labels': labels,
            'weights': weights, 'iteration': iteration,
            'objective': float(state.objective), 'max_iter': self.max_iter,
        }

    def fit(self, X: chex.Array, initial_centroids=None, resume=False) -> Self:
        """Fit KPCM model to data.

        Parameters
        ----------
        X : array
            Training data; passed through the base-class input validation.
        initial_centroids : array, optional
            Starting centroids. When given, ``gamma`` is derived from a
            uniform membership matrix instead of running KFCM.
        resume : bool, default=False
            Continue from the centroids, typicalities and scale parameters of
            a previous fit. Requires an already fitted model.

        Returns
        -------
        Self
            The fitted estimator.

        Raises
        ------
        ValueError
            If both ``resume=True`` and ``initial_centroids`` are supplied.
        """
        if resume and initial_centroids is not None:
            raise ValueError("Cannot use both resume=True and initial_centroids")

        X = self._validate_input(X)

        if resume:
            self._check_fitted()
            centroids_init = self.centroids_
            T_init = self.T_
            gamma_init = self.gamma_
        elif initial_centroids is not None:
            centroids_init = self._validate_initial_centroids(X, initial_centroids)
            # Derive gamma from uniform U, then T
            U_uniform = jnp.ones((X.shape[0], self.n_clusters)) / self.n_clusters
            gamma_init = self._compute_gamma(X, U_uniform, centroids_init)
            T_init = self._update_T(X, centroids_init, gamma_init)
        else:
            # KFCM initialisation; the membership matrix itself is not needed
            # beyond the gamma it produced.
            _, centroids_init, gamma_init = self._initialize(X)
            T_init = self._update_T(X, centroids_init, gamma_init)

        initial_objective = self._compute_objective(X, T_init, centroids_init, gamma_init)

        initial_state = KPCMState(
            centroids=centroids_init, T=T_init, gamma=gamma_init,
            objective=initial_objective, iteration=0, converged=False,
        )

        final_state, self.history_ = self._run_training(X, initial_state)

        # Store results
        self.centroids_ = final_state.centroids
        self.T_ = final_state.T
        self.gamma_ = final_state.gamma
        self.n_iter_ = int(final_state.iteration)
        self.objective_ = float(final_state.objective)
        self.objective_history_ = self.history_['objective']

        return self

    def predict(self, X: chex.Array) -> chex.Array:
        """Predict cluster labels for new data.

        Each sample is assigned to the cluster with the highest typicality
        under the fitted centroids and scale parameters.
        """
        self._check_fitted()
        # Coerce array-likes the same way ``predict_proba`` does so both
        # entry points accept the same inputs.
        X = jnp.asarray(X)
        T = self._update_T(X, self.centroids_, self.gamma_)
        labels = jnp.argmax(T, axis=1)
        return labels

    def predict_proba(self, X: chex.Array) -> chex.Array:
        """Predict typicality values.

        Returns the typicality matrix of ``X`` with respect to the fitted
        centroids and scale parameters. Rows are computed per cluster
        independently and are not normalised to sum to one.
        """
        self._check_fitted()
        X = jnp.asarray(X)
        T = self._update_T(X, self.centroids_, self.gamma_)
        return T