"""
JAX-based Kernel Fuzzy Possibilistic C-Means (KFPCM) clustering implementation.
"""
from typing import NamedTuple, Self
from functools import partial
import jax
import jax.numpy as jnp
import chex
from jax import jit
from prosemble.core.kernel import batch_gaussian_kernel
from prosemble.models.kfcm import KFCM
from prosemble.models.base import FuzzyClusteringBase, ScanFitMixin
class KFPCMState(NamedTuple):
centroids: chex.Array
U: chex.Array
T: chex.Array
objective: chex.Array
iteration: int
converged: bool
[docs]
class KFPCM(ScanFitMixin, FuzzyClusteringBase):
"""Kernel Fuzzy Possibilistic C-Means with JAX.
KFPCM maintains two matrices (:math:`U` and :math:`T`) in kernel space. :math:`U` has row-sum-to-1
constraint (standard FCM), while :math:`T` has column-sum-to-1 constraint per the
original Pal, Pal & Bezdek (1997) FPCM formulation.
Parameters
----------
fuzzifier : float, default=2.0
Fuzziness parameter for :math:`U` matrix (must be > 1.0).
eta : float, default=2.0
Fuzziness parameter for :math:`T` matrix (must be > 1.0).
sigma : float, default=1.0
Kernel bandwidth parameter (must be > 0).
init_method : {'kfcm', 'random'}, default='kfcm'
Initialization method.
n_clusters : int
Number of clusters (must be >= 2).
max_iter : int
Maximum number of iterations.
epsilon : float
Convergence threshold.
random_seed : int
Random seed for reproducibility.
distance_fn : callable, optional
Pairwise distance function. Default: squared Euclidean.
patience : int, optional
Epochs with no improvement before early stopping. Default: None.
restore_best : bool
If True, restore centroids from the lowest-objective epoch.
Default: False.
plot_steps : bool
Whether to visualize clustering progress. Default: False.
show_confidence : bool
Whether to show confidence in visualization. Default: True.
show_pca_variance : bool
Whether to show PCA variance in visualization. Default: True.
save_plot_path : str, optional
Path to save final plot.
callbacks : list, optional
List of Callback objects for monitoring/visualization.
"""
_hyperparams = ('fuzzifier', 'eta', 'sigma', 'init_method')
_fitted_array_names = ('U_', 'T_')
def __init__(self, n_clusters: int, fuzzifier: float = 2.0, eta: float = 2.0,
sigma: float = 1.0, max_iter: int = 100, epsilon: float = 1e-5,
init_method: str = 'kfcm', random_seed: int = 42, distance_fn=None,
patience: int | None = None, restore_best: bool = False,
plot_steps: bool = False, show_confidence: bool = True,
show_pca_variance: bool = True, save_plot_path: str | None = None,
callbacks=None):
super().__init__(
n_clusters=n_clusters, max_iter=max_iter, epsilon=epsilon,
random_seed=random_seed, distance_fn=distance_fn, patience=patience,
restore_best=restore_best, plot_steps=plot_steps,
show_confidence=show_confidence, show_pca_variance=show_pca_variance,
save_plot_path=save_plot_path, callbacks=callbacks,
)
if fuzzifier <= 1.0:
raise ValueError("fuzzifier must be > 1.0")
if eta <= 1.0:
raise ValueError("eta must be > 1.0")
if sigma <= 0:
raise ValueError("sigma must be > 0")
self.fuzzifier = fuzzifier
self.eta = eta
self.sigma = sigma
self.init_method = init_method
# Model-specific fitted attributes
self.U_ = None
self.T_ = None
def _initialize(self, X: chex.Array):
n_samples = X.shape[0]
if self.init_method == 'random':
alpha = jnp.ones(self.n_clusters)
U = jax.random.dirichlet(self.key, alpha, shape=(n_samples,))
T = jax.random.dirichlet(self.key, alpha, shape=(n_samples,))
indices = jax.random.choice(self.key, n_samples, shape=(self.n_clusters,), replace=False)
centroids = X[indices]
elif self.init_method == 'kfcm':
kfcm = KFCM(self.n_clusters, self.fuzzifier, self.sigma, self.max_iter,
self.epsilon, 'random', self.random_seed, False)
kfcm.fit(X)
U = kfcm.U_
centroids = kfcm.centroids_
alpha = jnp.ones(self.n_clusters)
T = jax.random.dirichlet(self.key, alpha, shape=(n_samples,))
else:
raise ValueError(f"Unknown init_method: {self.init_method}")
return U, T, centroids
@partial(jit, static_argnums=(0,))
def _compute_centroids(self, X: chex.Array, U: chex.Array, T: chex.Array, centroids: chex.Array) -> chex.Array:
K = batch_gaussian_kernel(X, centroids, self.sigma)
U_fuzz = jnp.power(U, self.fuzzifier)
T_fuzz = jnp.power(T, self.eta)
weights = (U_fuzz + T_fuzz) * K
numerator = weights.T @ X
denominator = jnp.sum(weights, axis=0, keepdims=True).T
return numerator / jnp.maximum(denominator, 1e-10)
@partial(jit, static_argnums=(0,))
def _update_fuzzy_matrix(self, X: chex.Array, centroids: chex.Array, fuzzifier: float) -> chex.Array:
K = batch_gaussian_kernel(X, centroids, self.sigma)
kernel_dist = 1.0 - K
kernel_dist = jnp.maximum(kernel_dist, 1e-10)
power = 1.0 / (fuzzifier - 1.0)
def compute_membership_row(distances_i):
ratios = distances_i[:, None] / distances_i[None, :]
powered_ratios = jnp.power(ratios, power)
denominators = jnp.sum(powered_ratios, axis=1)
return 1.0 / denominators
U = jax.vmap(compute_membership_row)(kernel_dist)
return U / jnp.sum(U, axis=1, keepdims=True)
@partial(jit, static_argnums=(0,))
def _update_typicality_matrix(self, X: chex.Array, centroids: chex.Array) -> chex.Array:
"""Update typicality matrix with column-sum-to-1 constraint (Pal et al. 1997)."""
K = batch_gaussian_kernel(X, centroids, self.sigma)
kernel_dist = 1.0 - K
kernel_dist = jnp.maximum(kernel_dist, 1e-10)
power = 1.0 / (self.eta - 1.0)
inv_dist_powered = jnp.power(1.0 / kernel_dist, power)
col_sums = jnp.maximum(jnp.sum(inv_dist_powered, axis=0, keepdims=True), 1e-10)
return inv_dist_powered / col_sums
@partial(jit, static_argnums=(0,))
def _compute_objective(self, X: chex.Array, U: chex.Array, T: chex.Array, centroids: chex.Array) -> chex.Array:
K = batch_gaussian_kernel(X, centroids, self.sigma)
kernel_dist = 2.0 * (1.0 - K)
U_fuzz = jnp.power(U, self.fuzzifier)
T_fuzz = jnp.power(T, self.eta)
weights = U_fuzz + T_fuzz
return jnp.sum(weights * kernel_dist)
@partial(jit, static_argnums=(0,))
def _iteration_step(self, state: KFPCMState, X: chex.Array) -> tuple[KFPCMState, dict]:
U_new = self._update_fuzzy_matrix(X, state.centroids, self.fuzzifier)
T_new = self._update_typicality_matrix(X, state.centroids)
centroids_new = self._compute_centroids(X, U_new, T_new, state.centroids)
objective = self._compute_objective(X, U_new, T_new, centroids_new)
centroid_change = jnp.linalg.norm(centroids_new - state.centroids, ord='fro')
converged = centroid_change <= self.epsilon
new_state = KFPCMState(centroids_new, U_new, T_new, objective, state.iteration + 1, converged)
metrics = {
'objective': new_state.objective,
'centroid_change': centroid_change,
'converged': new_state.converged,
}
return new_state, metrics
def _build_info(self, state, iteration):
labels = jnp.argmax(state.U, axis=1)
weights = (jnp.max(state.U, axis=1) + jnp.max(state.T, axis=1)) / 2.0
return {
'centroids': state.centroids, 'labels': labels,
'weights': weights, 'iteration': iteration,
'objective': float(state.objective), 'max_iter': self.max_iter,
}
[docs]
def fit(self, X: chex.Array, initial_centroids=None, resume=False) -> Self:
if resume and initial_centroids is not None:
raise ValueError("Cannot use both resume=True and initial_centroids")
X = self._validate_input(X)
if resume:
self._check_fitted()
centroids_init = self.centroids_
U_init = self._update_fuzzy_matrix(X, centroids_init, self.fuzzifier)
T_init = self._update_typicality_matrix(X, centroids_init)
elif initial_centroids is not None:
centroids_init = self._validate_initial_centroids(X, initial_centroids)
U_init = self._update_fuzzy_matrix(X, centroids_init, self.fuzzifier)
T_init = self._update_typicality_matrix(X, centroids_init)
else:
U_init, T_init, centroids_init = self._initialize(X)
initial_objective = self._compute_objective(X, U_init, T_init, centroids_init)
initial_state = KFPCMState(centroids_init, U_init, T_init, initial_objective, 0, False)
final_state, self.history_ = self._run_training(X, initial_state)
self.centroids_ = final_state.centroids
self.U_ = final_state.U
self.T_ = final_state.T
self.n_iter_ = int(final_state.iteration)
self.objective_ = float(final_state.objective)
self.objective_history_ = self.history_['objective']
return self
[docs]
def predict(self, X: chex.Array) -> chex.Array:
self._check_fitted()
U = self._update_fuzzy_matrix(X, self.centroids_, self.fuzzifier)
return jnp.argmax(U, axis=1)
[docs]
def predict_proba(self, X: chex.Array) -> chex.Array:
self._check_fitted()
X = jnp.asarray(X)
return self._update_fuzzy_matrix(X, self.centroids_, self.fuzzifier)
[docs]
def get_typicality(self, X: chex.Array) -> chex.Array:
self._check_fitted()
X = jnp.asarray(X)
return self._update_typicality_matrix(X, self.centroids_)