# Source code for prosemble.models.kpfcm

"""
JAX-based Kernel Possibilistic Fuzzy C-Means (KPFCM) clustering implementation.
"""

from typing import NamedTuple, Self
from functools import partial
import jax
import jax.numpy as jnp
import chex
from jax import jit
from prosemble.core.kernel import batch_gaussian_kernel
from prosemble.models.base import FuzzyClusteringBase, ScanFitMixin
from prosemble.models.kfcm import KFCM


class KPFCMState(NamedTuple):
    """Loop state carried from one KPFCM iteration to the next.

    A fresh instance is built at the end of every ``_iteration_step`` and
    once in ``fit`` for the starting point (iteration 0, not converged).
    """
    # Cluster centers produced by the most recent centroid update.
    centroids: chex.Array
    # Fuzzy memberships; assumes (n_samples, n_clusters) since labels are
    # taken with argmax over axis 1 — TODO confirm against the kernel helper.
    U: chex.Array
    # Typicalities, same layout as U.
    T: chex.Array
    # Per-cluster scale used by the typicality update (one value per cluster).
    gamma: chex.Array
    # Objective value evaluated for the fields above.
    objective: chex.Array
    # Number of iteration steps applied so far (0 for the initial state).
    iteration: int
    # True once the Frobenius norm of the centroid change is <= epsilon.
    converged: bool


class KPFCM(ScanFitMixin, FuzzyClusteringBase):
    """Kernel Possibilistic Fuzzy C-Means with JAX.

    KPFCM combines fuzzy membership (:math:`U`) and typicality (:math:`T`)
    in kernel space with weights :math:`a` and :math:`b`.

    Parameters
    ----------
    fuzzifier : float, default=2.0
        Fuzzification parameter for membership (must be > 1.0).
    eta : float, default=2.0
        Fuzzification parameter for typicality (must be > 1.0).
    a : float, default=1.0
        Weight for fuzzy membership term (must be > 0).
    b : float, default=1.0
        Weight for typicality term (must be > 0).
    k : float, default=1.0
        Scaling parameter for :math:`\\gamma` (must be > 0).
    sigma : float, default=1.0
        Kernel bandwidth parameter (must be > 0).
    init_method : {'kfcm'}, default='kfcm'
        Initialization method.
    n_clusters : int
        Number of clusters (must be >= 2).
    max_iter : int
        Maximum number of iterations.
    epsilon : float
        Convergence threshold.
    random_seed : int
        Random seed for reproducibility.
    distance_fn : callable, optional
        Pairwise distance function. Default: squared Euclidean.
    patience : int, optional
        Epochs with no improvement before early stopping. Default: None.
    restore_best : bool
        If True, restore centroids from the lowest-objective epoch.
        Default: False.
    plot_steps : bool
        Whether to visualize clustering progress. Default: False.
    show_confidence : bool
        Whether to show confidence in visualization. Default: True.
    show_pca_variance : bool
        Whether to show PCA variance in visualization. Default: True.
    save_plot_path : str, optional
        Path to save final plot.
    callbacks : list, optional
        List of Callback objects for monitoring/visualization.

    Attributes
    ----------
    U_ : array or None
        Fuzzy memberships of the training data, set by ``fit``.
    T_ : array or None
        Typicalities of the training data, set by ``fit``.
    gamma_ : array or None
        Per-cluster typicality scale, set by ``fit``.
    """

    _hyperparams = ('fuzzifier', 'sigma', 'a', 'b', 'eta', 'k', 'init_method')
    _fitted_array_names = ('U_', 'T_', 'gamma_')

    # NOTE(review): the numeric helpers below are jitted with ``self`` as a
    # static argument, so the hyperparameters they read from ``self`` become
    # trace-time constants. Confirm how the base class hashes instances before
    # mutating hyperparameters on an already-used model.

    def __init__(
        self,
        n_clusters: int,
        fuzzifier: float = 2.0,
        eta: float = 2.0,
        a: float = 1.0,
        b: float = 1.0,
        k: float = 1.0,
        sigma: float = 1.0,
        max_iter: int = 100,
        epsilon: float = 1e-5,
        init_method: str = 'kfcm',
        random_seed: int = 42,
        distance_fn=None,
        patience: int | None = None,
        restore_best: bool = False,
        plot_steps: bool = False,
        show_confidence: bool = True,
        show_pca_variance: bool = True,
        save_plot_path: str | None = None,
        callbacks=None,
    ):
        """Validate the model-specific hyperparameters and store them.

        Raises
        ------
        ValueError
            If ``fuzzifier`` or ``eta`` is not > 1.0, or if any of ``a``,
            ``b``, ``k``, ``sigma`` is not > 0.
        """
        # Validate model-specific parameters
        if fuzzifier <= 1.0:
            raise ValueError("fuzzifier must be > 1.0")
        if eta <= 1.0:
            raise ValueError("eta must be > 1.0")
        if a <= 0:
            raise ValueError("a must be > 0")
        if b <= 0:
            raise ValueError("b must be > 0")
        if k <= 0:
            raise ValueError("k must be > 0")
        if sigma <= 0:
            raise ValueError("sigma must be > 0")

        super().__init__(
            n_clusters=n_clusters,
            max_iter=max_iter,
            epsilon=epsilon,
            random_seed=random_seed,
            distance_fn=distance_fn,
            patience=patience,
            restore_best=restore_best,
            plot_steps=plot_steps,
            show_confidence=show_confidence,
            show_pca_variance=show_pca_variance,
            save_plot_path=save_plot_path,
            callbacks=callbacks,
        )

        self.fuzzifier = fuzzifier
        self.eta = eta
        self.a = a
        self.b = b
        self.k = k
        self.sigma = sigma
        # Stored only; _initialize always bootstraps from KFCM regardless of
        # this value.
        self.init_method = init_method

        # Model-specific fitted attributes
        self.U_ = None
        self.T_ = None
        self.gamma_ = None

    def _initialize(self, X: chex.Array):
        """Bootstrap memberships and centroids from a fitted KFCM model.

        Returns
        -------
        tuple
            ``(U, T, centroids)`` where ``U`` and ``centroids`` come from the
            KFCM fit and ``T`` is uniform at ``1 / n_clusters``.
        """
        # Positional call kept as-is; the last three arguments are presumably
        # KFCM's init method, seed and plot flag — verify against
        # KFCM.__init__.
        kfcm = KFCM(
            self.n_clusters,
            self.fuzzifier,
            self.sigma,
            self.max_iter,
            self.epsilon,
            'random',
            self.random_seed,
            False,
        )
        kfcm.fit(X)
        U = kfcm.U_
        centroids = kfcm.centroids_
        T = jnp.ones_like(U) / self.n_clusters
        return U, T, centroids

    @partial(jit, static_argnums=(0,))
    def _compute_centroids(self, X: chex.Array, U: chex.Array, T: chex.Array,
                           centroids: chex.Array) -> chex.Array:
        """Return updated centroids as a kernel-weighted mean of ``X``.

        The kernel is evaluated against the *current* ``centroids``; each
        sample's weight is ``(a * U**fuzzifier + b * T**eta) * K``.
        """
        K = batch_gaussian_kernel(X, centroids, self.sigma)
        U_fuzz = jnp.power(U, self.fuzzifier)
        T_fuzz = jnp.power(T, self.eta)
        weights = (self.a * U_fuzz + self.b * T_fuzz) * K
        numerator = weights.T @ X
        denominator = jnp.sum(weights, axis=0, keepdims=True).T
        # Clamp so a cluster with vanishing total weight cannot divide by zero.
        return numerator / jnp.maximum(denominator, 1e-10)

    @partial(jit, static_argnums=(0,))
    def _update_U(self, X: chex.Array, centroids: chex.Array) -> chex.Array:
        """Return fuzzy memberships of ``X`` with respect to ``centroids``.

        Uses the kernel-induced distance ``1 - K`` (clamped away from zero)
        in the usual fuzzy c-means ratio formula; rows are renormalized to
        sum to one.
        """
        K = batch_gaussian_kernel(X, centroids, self.sigma)
        kernel_dist = 1.0 - K
        # A sample sitting exactly on a centroid would give distance 0 and a
        # 0/0 ratio below.
        kernel_dist = jnp.maximum(kernel_dist, 1e-10)
        power = 1.0 / (self.fuzzifier - 1.0)

        def compute_membership_row(distances_i):
            # ratios[j, l] = d_j / d_l for one sample's per-cluster distances.
            ratios = distances_i[:, None] / distances_i[None, :]
            powered_ratios = jnp.power(ratios, power)
            denominators = jnp.sum(powered_ratios, axis=1)
            return 1.0 / denominators

        U = jax.vmap(compute_membership_row)(kernel_dist)
        return U / jnp.sum(U, axis=1, keepdims=True)

    @partial(jit, static_argnums=(0,))
    def _compute_gamma(self, X: chex.Array, U: chex.Array,
                       centroids: chex.Array) -> chex.Array:
        """Return the per-cluster typicality scale ``gamma``.

        ``gamma`` is ``k`` times the ``U**fuzzifier``-weighted mean of the
        kernel distance ``2 * (1 - K)`` for each cluster.
        """
        K = batch_gaussian_kernel(X, centroids, self.sigma)
        kernel_dist = 2.0 * (1.0 - K)
        U_fuzz = jnp.power(U, self.fuzzifier)
        numerator = jnp.sum(U_fuzz * kernel_dist, axis=0)
        denominator = jnp.sum(U_fuzz, axis=0)
        # Same guard as _compute_centroids: a membership column that
        # underflows to zero would otherwise yield 0/0 = NaN and poison every
        # later update.
        return self.k * numerator / jnp.maximum(denominator, 1e-10)

    @partial(jit, static_argnums=(0,))
    def _update_T(self, X: chex.Array, centroids: chex.Array,
                  gamma: chex.Array) -> chex.Array:
        """Return typicalities ``1 / (1 + (b * d / gamma) ** (1 / (eta - 1)))``.

        ``d`` is the kernel distance ``2 * (1 - K)``, clamped away from zero.
        """
        K = batch_gaussian_kernel(X, centroids, self.sigma)
        kernel_dist = 2.0 * (1.0 - K)
        kernel_dist = jnp.maximum(kernel_dist, 1e-10)
        power = 1.0 / (self.eta - 1.0)
        ratio = self.b * kernel_dist / gamma[None, :]
        return 1.0 / (1.0 + jnp.power(ratio, power))

    @partial(jit, static_argnums=(0,))
    def _compute_objective(self, X: chex.Array, U: chex.Array, T: chex.Array,
                           centroids: chex.Array,
                           gamma: chex.Array) -> chex.Array:
        """Return the KPFCM objective for the given partition.

        The value is the sum of the weighted kernel distances
        ``(a * U**fuzzifier + b * T**eta) * d`` plus the penalty
        ``sum_j gamma_j * sum_i (1 - T_ij) ** eta``.
        """
        K = batch_gaussian_kernel(X, centroids, self.sigma)
        kernel_dist = 2.0 * (1.0 - K)
        U_fuzz = jnp.power(U, self.fuzzifier)
        T_fuzz = jnp.power(T, self.eta)
        weights = self.a * U_fuzz + self.b * T_fuzz
        term1 = jnp.sum(weights * kernel_dist)

        one_minus_T = 1.0 - T
        one_minus_T_fuzz = jnp.power(one_minus_T, self.eta)
        inner_sum = jnp.sum(one_minus_T_fuzz, axis=0)
        term2 = jnp.sum(gamma * inner_sum)
        return term1 + term2

    @partial(jit, static_argnums=(0,))
    def _iteration_step(self, state: KPFCMState,
                        X: chex.Array) -> tuple[KPFCMState, dict]:
        """Run one KPFCM update and return the new state plus its metrics.

        ``U`` and ``T`` are refreshed from the previous centroids (``T`` also
        from the previous ``gamma``); centroids, ``gamma`` and the objective
        are then recomputed from the new partition.
        """
        U_new = self._update_U(X, state.centroids)
        T_new = self._update_T(X, state.centroids, state.gamma)
        centroids_new = self._compute_centroids(X, U_new, T_new, state.centroids)
        gamma_new = self._compute_gamma(X, U_new, centroids_new)
        objective = self._compute_objective(
            X, U_new, T_new, centroids_new, gamma_new)

        # Convergence is judged on centroid movement, not on the objective.
        centroid_change = jnp.linalg.norm(
            centroids_new - state.centroids, ord='fro')
        converged = centroid_change <= self.epsilon

        new_state = KPFCMState(
            centroids_new, U_new, T_new, gamma_new, objective,
            state.iteration + 1, converged)
        metrics = {
            'objective': new_state.objective,
            'centroid_change': centroid_change,
            'converged': new_state.converged,
        }
        return new_state, metrics

    def _build_info(self, state, iteration):
        """Build the per-iteration info dict from ``state``.

        Labels are the argmax of ``U``; ``weights`` is the mean of each
        sample's largest membership and largest typicality.
        """
        labels = jnp.argmax(state.U, axis=1)
        weights = (jnp.max(state.U, axis=1) + jnp.max(state.T, axis=1)) / 2.0
        return {
            'centroids': state.centroids,
            'labels': labels,
            'weights': weights,
            'iteration': iteration,
            'objective': float(state.objective),
            'max_iter': self.max_iter,
        }

    def fit(self, X: chex.Array, initial_centroids=None, resume=False) -> Self:
        """Fit the model to ``X``.

        Parameters
        ----------
        X : array
            Training data.
        initial_centroids : array, optional
            Starting centroids. When given, ``U``, ``gamma`` and ``T`` are
            derived from them instead of running the KFCM bootstrap.
        resume : bool, default=False
            If True, start from the ``centroids_`` of a previous fit.

        Returns
        -------
        Self
            The fitted estimator.

        Raises
        ------
        ValueError
            If ``resume=True`` is combined with ``initial_centroids``.
        """
        if resume and initial_centroids is not None:
            raise ValueError("Cannot use both resume=True and initial_centroids")

        X = self._validate_input(X)

        if resume or initial_centroids is not None:
            # Both warm-start paths only differ in where the centroids come
            # from; the rest of the state is derived from them identically.
            if resume:
                self._check_fitted()
                centroids_init = self.centroids_
            else:
                centroids_init = self._validate_initial_centroids(
                    X, initial_centroids)
            U_init = self._update_U(X, centroids_init)
            gamma_init = self._compute_gamma(X, U_init, centroids_init)
            T_init = self._update_T(X, centroids_init, gamma_init)
        else:
            U_init, T_init, centroids_init = self._initialize(X)
            gamma_init = self._compute_gamma(X, U_init, centroids_init)

        initial_objective = self._compute_objective(
            X, U_init, T_init, centroids_init, gamma_init)
        initial_state = KPFCMState(
            centroids_init, U_init, T_init, gamma_init,
            initial_objective, 0, False)

        final_state, self.history_ = self._run_training(X, initial_state)

        self.centroids_ = final_state.centroids
        self.U_ = final_state.U
        self.T_ = final_state.T
        self.gamma_ = final_state.gamma
        self.n_iter_ = int(final_state.iteration)
        self.objective_ = float(final_state.objective)
        self.objective_history_ = self.history_['objective']
        return self

    def predict(self, X: chex.Array) -> chex.Array:
        """Return the index of the highest-membership cluster for each sample."""
        self._check_fitted()
        # Coerce like predict_proba / get_typicality so array-likes behave
        # the same across all three inference methods.
        X = jnp.asarray(X)
        U = self._update_U(X, self.centroids_)
        return jnp.argmax(U, axis=1)

    def predict_proba(self, X: chex.Array) -> chex.Array:
        """Return the fuzzy memberships of ``X`` for the fitted centroids."""
        self._check_fitted()
        X = jnp.asarray(X)
        return self._update_U(X, self.centroids_)

    def get_typicality(self, X: chex.Array) -> chex.Array:
        """Return the typicalities of ``X`` for the fitted centroids and gamma."""
        self._check_fitted()
        X = jnp.asarray(X)
        return self._update_T(X, self.centroids_, self.gamma_)