# Source code for prosemble.models.kafcm

"""
JAX-based Kernel Adaptive Fuzzy C-Means (KAFCM) clustering implementation.

This module provides a GPU-accelerated implementation of KAFCM using JAX
with JIT compilation for high performance.
"""

from typing import NamedTuple, Self
from functools import partial

import jax
import jax.numpy as jnp
import chex
from jax import jit

from prosemble.core.kernel import batch_gaussian_kernel
from prosemble.models.base import FuzzyClusteringBase, ScanFitMixin
from prosemble.models.kfcm import KFCM


class KAFCMState(NamedTuple):
    """Immutable state for KAFCM iteration.

    One instance is built by ``KAFCM.fit`` for iteration 0 and a fresh one is
    produced by every call to ``KAFCM._iteration_step``.
    """
    # Current cluster prototypes; produced by ``_compute_centroids`` as
    # ``weights.T @ X`` (presumably (n_clusters, n_features) -- confirm).
    centroids: chex.Array
    # Fuzzy membership matrix from ``_update_U``; rows are normalised to sum to 1.
    U: chex.Array
    # Typicality matrix from ``_update_T``; initialised as zeros of shape
    # (n_samples, n_clusters) on a fresh fit.
    T: chex.Array
    # Scale parameters from ``_compute_gamma`` (reduced over axis 0 of U,
    # i.e. one value per column of U).
    gamma: chex.Array
    # Scalar objective value from ``_compute_objective`` for this state.
    objective: chex.Array
    # Iteration counter: 0 for the initial state, incremented by one per step.
    iteration: int
    # True once the Frobenius norm of the centroid change is <= epsilon.
    converged: bool


class KAFCM(ScanFitMixin, FuzzyClusteringBase):
    """
    Kernel Adaptive Fuzzy C-Means clustering with JAX.

    KAFCM extends AFCM to kernel space, combining fuzzy and possibilistic
    approaches with kernel-based distance measures.

    Kernel distance: :math:`d_K(x, v) = 2(1 - K(x, v))`

    Algorithm:

    1. Initialize :math:`U` using KFCM
    2. Compute :math:`\\gamma` parameters using kernel distance
    3. Update :math:`T` using exponential kernel update
    4. Update :math:`U` using standard KFCM rule
    5. Update centroids (kernel-weighted with combined weights)
    6. Repeat until convergence

    Parameters
    ----------
    fuzzifier : float, default=2.0
        Fuzziness parameter (must be > 1.0).
    a : float, default=1.0
        Weight for fuzzy membership (must be > 0).
    b : float, default=1.0
        Weight for typicality (must be > 0).
    k : float, default=1.0
        Scaling parameter for :math:`\\gamma` (must be > 0).
    sigma : float, default=1.0
        Kernel bandwidth parameter (must be > 0).
    init_method : {'kfcm'}, default='kfcm'
        Initialization method. The value is stored on the instance;
        ``_initialize`` always runs KFCM.
    n_clusters : int
        Number of clusters (must be >= 2).
    max_iter : int, default=100
        Maximum number of iterations.
    epsilon : float, default=1e-5
        Convergence threshold on the Frobenius norm of the centroid change.
    random_seed : int, default=42
        Random seed for reproducibility.
    distance_fn : callable, optional
        Pairwise distance function. Default: squared Euclidean.
    patience : int, optional
        Epochs with no improvement before early stopping. Default: None.
    restore_best : bool
        If True, restore centroids from the lowest-objective epoch.
        Default: False.
    plot_steps : bool
        Whether to visualize clustering progress. Default: False.
    show_confidence : bool
        Whether to show confidence in visualization. Default: True.
    show_pca_variance : bool
        Whether to show PCA variance in visualization. Default: True.
    save_plot_path : str, optional
        Path to save final plot.
    callbacks : list, optional
        List of Callback objects for monitoring/visualization.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from prosemble.models import KAFCM
    >>> X = jnp.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8]])
    >>> model = KAFCM(n_clusters=2, sigma=1.0, random_seed=42)
    >>> _ = model.fit(X)
    >>> labels = model.predict(X)
    """

    _hyperparams = ('fuzzifier', 'sigma', 'a', 'b', 'k', 'init_method')
    _fitted_array_names = ('U_', 'T_', 'gamma_')

    def __init__(
        self,
        n_clusters: int,
        fuzzifier: float = 2.0,
        a: float = 1.0,
        b: float = 1.0,
        k: float = 1.0,
        sigma: float = 1.0,
        max_iter: int = 100,
        epsilon: float = 1e-5,
        init_method: str = 'kfcm',
        random_seed: int = 42,
        distance_fn=None,
        patience: int | None = None,
        restore_best: bool = False,
        plot_steps: bool = False,
        show_confidence: bool = True,
        show_pca_variance: bool = True,
        save_plot_path: str | None = None,
        callbacks=None,
    ):
        """Validate the KAFCM-specific hyperparameters and store them.

        Raises
        ------
        ValueError
            If ``fuzzifier <= 1.0`` or any of ``a``, ``b``, ``k``, ``sigma``
            is not strictly positive.
        """
        # Validate model-specific parameters
        if fuzzifier <= 1.0:
            raise ValueError("fuzzifier must be > 1.0")
        if a <= 0:
            raise ValueError("a must be > 0")
        if b <= 0:
            raise ValueError("b must be > 0")
        if k <= 0:
            raise ValueError("k must be > 0")
        if sigma <= 0:
            raise ValueError("sigma must be > 0")

        super().__init__(
            n_clusters=n_clusters,
            max_iter=max_iter,
            epsilon=epsilon,
            random_seed=random_seed,
            distance_fn=distance_fn,
            patience=patience,
            restore_best=restore_best,
            plot_steps=plot_steps,
            show_confidence=show_confidence,
            show_pca_variance=show_pca_variance,
            save_plot_path=save_plot_path,
            callbacks=callbacks,
        )

        self.fuzzifier = fuzzifier
        self.a = a
        self.b = b
        self.k = k
        self.sigma = sigma
        self.init_method = init_method

        # Model-specific fitted attributes
        self.U_ = None
        self.T_ = None
        self.gamma_ = None

    def _initialize(self, X: chex.Array):
        """Initialize memberships and centroids by fitting a KFCM model.

        Returns
        -------
        tuple
            ``(U, T, centroids)`` where ``U`` and ``centroids`` come from the
            fitted KFCM model and ``T`` is an all-zero
            ``(n_samples, n_clusters)`` array.
        """
        n_samples = X.shape[0]

        kfcm = KFCM(
            n_clusters=self.n_clusters,
            fuzzifier=self.fuzzifier,
            sigma=self.sigma,
            max_iter=self.max_iter,
            epsilon=self.epsilon,
            random_seed=self.random_seed,
            plot_steps=False,
        )
        kfcm.fit(X)

        U = kfcm.U_
        centroids = kfcm.centroids_
        T = jnp.zeros((n_samples, self.n_clusters))
        return U, T, centroids

    @partial(jit, static_argnums=(0,))
    def _compute_gamma(
        self,
        X: chex.Array,
        U: chex.Array,
        centroids: chex.Array
    ) -> chex.Array:
        """Compute :math:`\\gamma` using kernel distance.

        :math:`\\gamma_j = k \\cdot \\sum_i u_{ij}^m d_K(x_i, v_j)
        / \\sum_i u_{ij}^m`
        """
        K = batch_gaussian_kernel(X, centroids, self.sigma)
        kernel_dist = 2.0 * (1.0 - K)

        U_fuzz = jnp.power(U, self.fuzzifier)
        numerator = jnp.sum(U_fuzz * kernel_dist, axis=0)
        # Guard against a cluster whose fuzzy mass underflows to zero; without
        # the floor this is 0/0 = NaN, which then poisons T and the objective.
        # Same floor as the denominator guard in ``_compute_centroids``.
        denominator = jnp.maximum(jnp.sum(U_fuzz, axis=0), 1e-10)

        gamma = self.k * numerator / denominator
        return gamma

    @partial(jit, static_argnums=(0,))
    def _update_T(
        self,
        X: chex.Array,
        centroids: chex.Array,
        gamma: chex.Array
    ) -> chex.Array:
        """Update T: :math:`t_{ij} = \\exp(-b \\cdot d_K(x_i, v_j) / \\gamma_j)`"""
        K = batch_gaussian_kernel(X, centroids, self.sigma)
        kernel_dist = 2.0 * (1.0 - K)
        # Floor the distance so a point sitting on a centroid stays finite.
        kernel_dist = jnp.maximum(kernel_dist, 1e-10)

        ratio = self.b * kernel_dist / gamma[None, :]
        T = jnp.exp(-ratio)
        return T

    @partial(jit, static_argnums=(0,))
    def _update_U(self, X: chex.Array, centroids: chex.Array) -> chex.Array:
        """Update :math:`U` using kernel distance.

        Uses ``1 - K`` rather than ``2(1 - K)``: the rule only depends on
        ratios of distances, so the constant factor cancels.
        """
        K = batch_gaussian_kernel(X, centroids, self.sigma)
        kernel_dist = 1.0 - K
        kernel_dist = jnp.maximum(kernel_dist, 1e-10)

        power = 1.0 / (self.fuzzifier - 1.0)

        def compute_membership_row(distances_i):
            # ratios[j, l] = d_j / d_l for a single sample
            ratios = distances_i[:, None] / distances_i[None, :]
            powered_ratios = jnp.power(ratios, power)
            denominators = jnp.sum(powered_ratios, axis=1)
            memberships = 1.0 / denominators
            return memberships

        U = jax.vmap(compute_membership_row)(kernel_dist)
        # Renormalise each row so memberships sum to one.
        U = U / jnp.sum(U, axis=1, keepdims=True)
        return U

    @partial(jit, static_argnums=(0,))
    def _compute_centroids(
        self,
        X: chex.Array,
        U: chex.Array,
        T: chex.Array,
        centroids: chex.Array
    ) -> chex.Array:
        """Compute centroids: kernel-weighted with :math:`a \\cdot U^m + b \\cdot T`."""
        K = batch_gaussian_kernel(X, centroids, self.sigma)
        U_fuzz = jnp.power(U, self.fuzzifier)

        weights = (self.a * U_fuzz + self.b * T) * K
        numerator = weights.T @ X
        denominator = jnp.sum(weights, axis=0, keepdims=True).T
        # Avoid division by zero for a cluster with vanishing total weight.
        denominator = jnp.maximum(denominator, 1e-10)

        centroids_new = numerator / denominator
        return centroids_new

    @partial(jit, static_argnums=(0,))
    def _compute_objective(
        self,
        X: chex.Array,
        U: chex.Array,
        T: chex.Array,
        centroids: chex.Array,
        gamma: chex.Array
    ) -> chex.Array:
        """Compute KAFCM objective.

        Sum of the weighted kernel distances
        (:math:`\\sum (a u^m + b t) d_K`) and the term
        :math:`\\sum_j \\gamma_j \\sum_i (t_{ij} \\log t_{ij} - t_{ij})`.
        """
        K = batch_gaussian_kernel(X, centroids, self.sigma)
        kernel_dist = 2.0 * (1.0 - K)

        U_fuzz = jnp.power(U, self.fuzzifier)
        weights = self.a * U_fuzz + self.b * T
        term1 = jnp.sum(kernel_dist * weights)

        # Clamp only inside the log so that T == 0 contributes exactly 0.
        T_safe = jnp.maximum(T, 1e-10)
        entropy_like = T * jnp.log(T_safe) - T
        inner_sum = jnp.sum(entropy_like, axis=0)
        term2 = jnp.sum(gamma * inner_sum)

        objective = term1 + term2
        return objective

    @partial(jit, static_argnums=(0,))
    def _iteration_step(self, state: KAFCMState, X: chex.Array) -> tuple[KAFCMState, dict]:
        """Run a single iteration and return the new state plus metrics.

        T and U are updated from the previous centroids, then the centroids,
        gamma and objective are recomputed. Convergence is declared when the
        Frobenius norm of the centroid change is ``<= epsilon``.
        """
        T_new = self._update_T(X, state.centroids, state.gamma)
        U_new = self._update_U(X, state.centroids)
        centroids_new = self._compute_centroids(X, U_new, T_new, state.centroids)
        gamma_new = self._compute_gamma(X, U_new, centroids_new)
        objective = self._compute_objective(X, U_new, T_new, centroids_new, gamma_new)

        centroid_change = jnp.linalg.norm(centroids_new - state.centroids, ord='fro')
        converged = centroid_change <= self.epsilon

        new_state = KAFCMState(
            centroids_new, U_new, T_new, gamma_new,
            objective, state.iteration + 1, converged,
        )
        metrics = {
            'objective': new_state.objective,
            'centroid_change': centroid_change,
            'converged': new_state.converged,
        }
        return new_state, metrics

    def _build_info(self, state, iteration):
        """Build the info dict for ``state`` at ``iteration``.

        Labels are the arg-max of ``U``; the per-sample weight is the largest
        product of membership and typicality across clusters.
        """
        labels = jnp.argmax(state.U, axis=1)
        weights = jnp.max(state.U * state.T, axis=1)
        return {
            'centroids': state.centroids,
            'labels': labels,
            'weights': weights,
            'iteration': iteration,
            'objective': float(state.objective),
            'max_iter': self.max_iter,
        }

    def fit(self, X: chex.Array, initial_centroids=None, resume=False) -> Self:
        """Fit model.

        Parameters
        ----------
        X : array-like
            Training data; passed through ``_validate_input``.
        initial_centroids : array-like, optional
            Starting centroids. When given, KFCM initialization is skipped.
        resume : bool, default=False
            If True, continue from the centroids of a previous fit.

        Returns
        -------
        Self
            The fitted estimator.

        Raises
        ------
        ValueError
            If both ``resume=True`` and ``initial_centroids`` are supplied.
        """
        if resume and initial_centroids is not None:
            raise ValueError("Cannot use both resume=True and initial_centroids")

        X = self._validate_input(X)

        if resume or initial_centroids is not None:
            if resume:
                self._check_fitted()
                centroids_init = self.centroids_
            else:
                centroids_init = self._validate_initial_centroids(X, initial_centroids)
            # Both warm-start paths bootstrap U, gamma and T from the centroids.
            U_init = self._update_U(X, centroids_init)
            gamma_init = self._compute_gamma(X, U_init, centroids_init)
            T_init = self._update_T(X, centroids_init, gamma_init)
        else:
            U_init, T_init, centroids_init = self._initialize(X)
            gamma_init = self._compute_gamma(X, U_init, centroids_init)

        initial_objective = self._compute_objective(
            X, U_init, T_init, centroids_init, gamma_init
        )
        initial_state = KAFCMState(
            centroids_init, U_init, T_init, gamma_init,
            initial_objective, 0, False,
        )

        final_state, self.history_ = self._run_training(X, initial_state)

        self.centroids_ = final_state.centroids
        self.U_ = final_state.U
        self.T_ = final_state.T
        self.gamma_ = final_state.gamma
        self.n_iter_ = int(final_state.iteration)
        self.objective_ = float(final_state.objective)
        self.objective_history_ = self.history_['objective']
        return self

    def predict(self, X: chex.Array) -> chex.Array:
        """Predict labels (arg-max of the fuzzy memberships)."""
        self._check_fitted()
        # Coerce to an array first, as predict_proba/get_typicality do: the
        # jitted update would otherwise see a nested list as a pytree of scalars.
        X = jnp.asarray(X)
        U = self._update_U(X, self.centroids_)
        return jnp.argmax(U, axis=1)

    def predict_proba(self, X: chex.Array) -> chex.Array:
        """Predict probabilities (fuzzy memberships w.r.t. fitted centroids)."""
        self._check_fitted()
        X = jnp.asarray(X)
        return self._update_U(X, self.centroids_)

    def get_typicality(self, X: chex.Array) -> chex.Array:
        """Get typicality of ``X`` w.r.t. the fitted centroids and gamma."""
        self._check_fitted()
        X = jnp.asarray(X)
        return self._update_T(X, self.centroids_, self.gamma_)